from __future__ import annotations
import copy
import logging
import warnings
import polars as pl
import polars.selectors as cs
from corr_vars import logger, __
import corr_vars.sources as corr
import corr_vars.utils as utils
from collections.abc import Collection, Iterable, Mapping
from corr_vars.definitions.exceptions import (
ObsLevelNotSupportedError,
VariableNotFoundError,
)
from corr_vars.definitions.constants import DYN_COLUMNS, COL_ORDER, PRIMARY_KEYS
from corr_vars.utils.time import TimeAnchor
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from corr_vars.core.cohort import Cohort, ObsLevel
from corr_vars.definitions.typing import (
VariableLoaderProtocol,
VariableProtocol,
ObsLevelName,
)
from corr_vars.utils.time import TimeWindow
# ---------------------------------------------------------------------------
# MultiSourceVariable with pre and post-processors
# ---------------------------------------------------------------------------
def validate_primary_key(var_df: pl.DataFrame, primary_key: str) -> None:
    """Check that `primary_key` is an accepted key and a column of `var_df`.

    Args:
        var_df (pl.DataFrame): Variable data whose columns are inspected.
        primary_key (str): Name of the primary key column to look for.

    Raises:
        ValueError: If `primary_key` is not a member of `PRIMARY_KEYS`, or if it
            is a member but `var_df` has no column with that name.
    """
    # The membership check comes first so an unknown key is reported as such,
    # regardless of which columns the frame happens to contain.
    is_known_key = primary_key in PRIMARY_KEYS
    if not is_known_key:
        raise ValueError(f"Primary key {primary_key} must be one of {PRIMARY_KEYS}.")

    available_columns = var_df.columns
    if primary_key not in available_columns:
        raise ValueError(f"Primary key {primary_key} not in data columns.")
def add_relative_times(
    reference_df: pl.DataFrame,
    var_df: pl.DataFrame,
    *,
    reference_col: str,
    join_col: str,
    suffix: str = "_relative",
    include_time_cols: Collection[str] | None = None,
) -> pl.DataFrame:
    """Append relative-time columns to `var_df`, measured against a reference column.

    The reference column is left-joined onto `var_df` via `join_col`. Every
    temporal column selected (all of them, or only those named in
    `include_time_cols`) is passed through `utils.time_difference` with
    `unit="s"` and `total=True`, and the result is stored under the original
    column name plus `suffix`. The joined reference column is dropped again
    before returning.

    Args:
        reference_df (pl.DataFrame): Frame providing `join_col` and `reference_col`.
        var_df (pl.DataFrame): Variable data to extend.
        reference_col (str): Column of `reference_df` holding the reference time.
        join_col (str): Column used to join `reference_df` onto `var_df`.
        suffix (str): Suffix for the new columns. Defaults to `"_relative"`.
        include_time_cols (Collection[str] | None): Temporal columns to restrict
            the computation to. Names that are not present are ignored
            (`require_all=False`). `None` selects every temporal column.

    Returns:
        pl.DataFrame: `var_df` plus the relative-time columns, without `reference_col`.
    """
    # Bring the reference timestamp next to every row of the variable data.
    reference_times = reference_df.select(join_col, reference_col)
    joined = var_df.join(reference_times, on=join_col, how="left")

    # NOTE(review): the selector is resolved against the joined frame, so with
    # include_time_cols=None a temporal reference column is matched as well and
    # presumably yields a `<reference_col><suffix>` column - confirm this is intended.
    time_selector = cs.temporal()
    if include_time_cols is not None:
        time_selector &= cs.by_name(include_time_cols, require_all=False)

    relative_times = time_selector.pipe(
        utils.time_difference,
        reference_col=reference_col,
        unit="s",
        total=True,
    ).name.suffix(suffix)

    # The reference column was only needed for the computation above.
    return joined.with_columns(relative_times).drop(reference_col)
def unify_and_order_columns(var_df: pl.DataFrame) -> pl.DataFrame:
    """Enforce `COL_ORDER` and fill missing dynamic columns with null.

    Columns listed in `DYN_COLUMNS` that are absent from `var_df` are added as
    null literals. The result is then reordered: first the columns that appear
    in `COL_ORDER` (in that order), then every other column in the order it has
    in the frame. The presence of such other columns triggers a `UserWarning`
    announcing their upcoming removal.

    Args:
        var_df (pl.DataFrame): The variable DataFrame to standardize.

    Returns:
        pl.DataFrame: `var_df` with missing dynamic columns added as null and
        columns reordered deterministically.
    """
    # Add missing dynamic columns (no keys) with null values
    present_cols = set(var_df.columns)
    missing_cols = [col for col in DYN_COLUMNS if col not in present_cols]
    if missing_cols:
        var_df = var_df.with_columns(pl.lit(None).alias(col) for col in missing_cols)

    # Keep the frame's column list as a list: iterating a set of strings here
    # would make the order of the unknown columns vary between interpreter runs
    # (string hashing is randomized), and with it the output schema.
    data_cols = var_df.columns
    data_col_set = set(data_cols)
    known_cols = set(COL_ORDER)

    # Order columns according to COL_ORDER, then add remaining columns
    ordered_cols = [col for col in COL_ORDER if col in data_col_set]
    remaining_cols = [col for col in data_cols if col not in known_cols]

    if remaining_cols:
        warnings.warn(
            "🚨 DEPRECATION ALERT 🚨\n"
            "Unknown columns found.\n"
            "- Please update the variable definition to move these to attributes.\n"
            "- Unknown columns are no longer supported and might be deprecated soon!\n"
            "- Columns: " + ", ".join(remaining_cols),
            UserWarning,
        )

    # TODO: once the warning period is over, select only `ordered_cols` here.
    return var_df.select(ordered_cols + remaining_cols)
class MultiSourceVariable:
    """Holds one :class:`~corr_vars.definitions.typing.VariableProtocol` per source and
    extracts, post-validates, and post-processes each before concatenating the results.

    The shared variable name, dynamic flag and time window are taken from the
    first entry of `variables`; every other entry must agree with them.

    Args:
        variables (Mapping[str, VariableProtocol]): Mapping from source name to the
            source-specific variable implementation. Must not be empty.

    Raises:
        ValueError: If `variables` is empty.
        AssertionError: If the per-source variables disagree on `var_name`,
            `dynamic` or `time_window`.
    """

    variables: Mapping[str, VariableProtocol]
    data: pl.DataFrame | None
    _var_name: str
    _dynamic: bool
    _time_window: TimeWindow

    def __init__(self, variables: Mapping[str, VariableProtocol]) -> None:
        self.variables = variables
        # Populated by `extract`; `None` means nothing has been extracted yet.
        self.data = None
        self._pre_validate()

    def _pre_validate(self) -> None:
        """Derive the shared attributes and verify all sources agree on them.

        Raises:
            ValueError: If `self.variables` is empty.
            AssertionError: On any name, dynamic-status or time-window mismatch.
        """
        if not self.variables:
            raise ValueError("No variables to concatenate")

        # Assign var_name, dynamic and time_window after verifying variables is not empty
        first_variable = next(iter(self.variables.values()))
        self._var_name = first_variable.var_name
        self._dynamic = first_variable.dynamic
        self._time_window = first_variable.time_window

        # Raised explicitly instead of via `assert` so the checks are not stripped
        # when Python runs with -O; the exception type and messages are unchanged.
        for var in self.variables.values():
            if not (var.var_name == self._var_name):
                raise AssertionError(
                    f"Variable name mismatch: {var.var_name} != {self._var_name}"
                )
            if not (var.dynamic == self._dynamic):
                raise AssertionError(
                    f"Dynamic status mismatch for {self._var_name}: {var.dynamic} != {self._dynamic}"
                )
            if not (var.time_window == self._time_window):
                raise AssertionError(
                    f"Timebound mismatch for {self._var_name}: {var.time_window} != {self._time_window}"
                )

    def extract(self, cohort: Cohort, with_source_column: bool = False) -> pl.DataFrame:
        """Extract the variable from every source and concatenate the results.

        Each source's data is validated and post-processed individually before
        being combined with a `diagonal_relaxed` concat. The combined frame is
        also stored in `self.data`.

        Args:
            cohort (Cohort): Cohort passed to every source variable's `extract`.
            with_source_column (bool): When `True`, a leading `data_source`
                column holding the source name is inserted into each frame.

        Returns:
            pl.DataFrame: The concatenated data of all sources.
        """
        var_data: list[pl.DataFrame] = []
        for src, var in self.variables.items():
            # Extract and validate data
            extracted = var.extract(cohort)
            self._post_validate(cohort, extracted)

            # Transform data
            extracted = self._post_process(cohort, extracted)
            if with_source_column:
                extracted = extracted.insert_column(0, pl.lit(src).alias("data_source"))

            # Add to list
            var_data.append(extracted)
            logger.info(
                __(
                    "SUCCESS: Extracted {var_name} from {src}",
                    var_name=self._var_name,
                    src=src,
                )
            )

        # Concatenate data
        combined_data = pl.concat(var_data, how="diagonal_relaxed")
        self.data = combined_data
        return combined_data

    def _post_process(self, cohort: Cohort, var_data: pl.DataFrame) -> pl.DataFrame:
        """Standardize dynamic variable data; static data is returned unchanged.

        For dynamic variables, relative-time columns for `recordtime` and
        `recordtime_end` are added (relative to `cohort.t_min`) unless both
        suffixed columns already exist, and the columns are then unified and
        ordered via `unify_and_order_columns`.
        """
        if not self.dynamic:
            return var_data

        processed = var_data
        include_time_cols = ("recordtime", "recordtime_end")
        suffix = "_relative"

        # Skip the computation only when every expected relative column is present.
        expected_relative_cols = frozenset(f"{col}{suffix}" for col in include_time_cols)
        if not expected_relative_cols.issubset(frozenset(var_data.columns)):
            processed = add_relative_times(
                cohort.obs,
                processed,
                reference_col=cohort.t_min,
                join_col=cohort.primary_key,
                suffix=suffix,
                include_time_cols=include_time_cols,
            )

        processed = unify_and_order_columns(processed)
        return processed

    def _post_validate(self, cohort: Cohort, var_data: pl.DataFrame) -> None:
        """Validate extracted data against the cohort's primary key."""
        validate_primary_key(var_data, cohort.primary_key)

    @property
    def var_name(self) -> str:
        """Variable name shared by all sources (read-only)."""
        return self._var_name

    @var_name.setter
    def var_name(self, _: str) -> None:
        raise AttributeError("Cannot set var_name")

    @property
    def dynamic(self) -> bool:
        """Dynamic flag shared by all sources (read-only)."""
        return self._dynamic

    @dynamic.setter
    def dynamic(self, _: bool) -> None:
        raise AttributeError("Cannot set dynamic")

    @property
    def time_window(self) -> TimeWindow:
        """Time window shared by all sources (read-only)."""
        return self._time_window

    @time_window.setter
    def time_window(self, _: TimeWindow) -> None:
        # NOTE: raises ValueError (not AttributeError like the other setters);
        # kept as-is so existing callers see the same exception type.
        raise ValueError("Cannot set time_window")

    @property
    def extracted(self) -> bool:
        """Whether `extract` has stored data on this instance."""
        return self.data is not None

    def __repr__(self) -> str:
        return (
            f"MultiSourceVariable(var_name='{self._var_name}', "
            f"dynamic={self._dynamic}, "
            f"time_window={self._time_window}, "
            f"sources={list(self.variables.keys())}, "
            f"extracted={self.extracted})"
        )

    def __str__(self) -> str:
        return self.__repr__()
# ---------------------------------------------------------------------------
# load_raw_variable_configs -> dict[str, dict[str, dict[str, object]]]
# ---------------------------------------------------------------------------
def load_raw_variable_configs(
    include_sources: Iterable[str] | None = None,
) -> dict[str, dict[str, dict[str, object]]]:
    """Gather the `"variables"` section of each source's `VARS` mapping.

    Args:
        include_sources (Iterable[str] | None): Sources to look at. Forwarded to
            `corr.get_src_module_mapping`; `None` applies no restriction here.

    Returns:
        dict[str, dict[str, dict[str, object]]]: Source name mapped to that
        source's raw variable configs (the same objects, not copies).
    """
    mapping_modules = corr.get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )
    return {
        source: mapping_module.VARS["variables"]
        for source, mapping_module in mapping_modules.items()
    }
# ---------------------------------------------------------------------------
# _load_variable_configs -> dict[str, dict[str, object]]
# ---------------------------------------------------------------------------
def _collect_variable_configs_by_source(
    var_name: str,
    include_sources: Iterable[str] | None = None,
) -> dict[str, dict[str, object]]:
    """Look up the config of `var_name` in every source's mapping module.

    For each source that declares `var_name` under `VARS["variables"]`, a
    shallow copy of the entry is returned. If the source's `variables`
    submodule has an attribute named `var_name`, it is stored in that copy
    under the `"py"` key.

    Args:
        var_name (str): Name of the variable to look up.
        include_sources (Iterable[str] | None): Sources to look at. Forwarded to
            `corr.get_src_module_mapping`; `None` applies no restriction here.

    Returns:
        dict[str, dict[str, object]]: Source name mapped to its config for
        `var_name`. Sources that do not declare the variable are omitted.
    """
    missing = object()  # sentinel: distinguishes "no attribute" from a None attribute
    configs: dict[str, dict[str, Any]] = {}

    mapping_modules = corr.get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )
    for source, mapping_module in mapping_modules.items():
        declared_variables = mapping_module.VARS["variables"]
        # Resolved before the membership check, so a mapping module without a
        # `variables` attribute fails here even if it does not declare var_name.
        variable_code = mapping_module.variables

        if var_name in declared_variables:
            entry = declared_variables[var_name].copy()
            py_impl = getattr(variable_code, var_name, missing)
            if py_impl is not missing:
                entry["py"] = py_impl
            configs[source] = entry

    return configs
def _load_variable_configs(
    var_name: str,
    cohort: Cohort,
    include_sources: Iterable[str] | None = None,
) -> dict[str, dict[str, object]]:
    """Build the per-source configuration of `var_name`, including project overrides.

    The base configuration comes from
    :func:`_collect_variable_configs_by_source`; the value returned by
    :meth:`~corr_vars.core.cohort.Cohort.get_variable_definition` is then merged
    on top of it with `utils.deep_merge`.

    Args:
        var_name (str): Name of the variable to load configuration for.
        cohort (Cohort): Cohort providing the project-specific override.
        include_sources (Iterable[str] | None): Sources to consider; passed
            through to the collector unchanged.

    Returns:
        dict[str, dict[str, object]]: Result of merging the override into the
        per-source base configuration.
    """
    # Base definition as shipped with each source
    base_configs = _collect_variable_configs_by_source(
        var_name=var_name,
        include_sources=include_sources,
    )

    # Project-specific definition takes precedence over the base
    project_override = cohort.get_variable_definition(var_name)
    return utils.deep_merge(base=base_configs, override=project_override)
# ---------------------------------------------------------------------------
# load_variable -> MultiSourceVariable
# ---------------------------------------------------------------------------
def _transfer_time_window_from_variable_config(
    variable_configs_by_source: dict[str, dict[str, Any]],
    time_window: TimeWindow,
) -> tuple[dict[str, dict[str, object]], TimeWindow]:
    """Move `tmin`/`tmax` entries from the source configs onto the time window.

    Works on a deep copy of the configs and a shallow copy (`copy.copy`) of
    `time_window`; the arguments themselves are not modified. Whenever a source
    config holds a `"tmin"` or `"tmax"` key, a warning is logged, the key is
    removed from the config copy, and the corresponding attribute of the time
    window copy is set to `TimeAnchor(value)`. If several sources carry the same
    key, the one visited last determines the final value.

    Args:
        variable_configs_by_source (dict[str, dict[str, Any]]): Per-source config dicts.
        time_window (TimeWindow): The time window to start from.

    Returns:
        tuple[dict[str, dict[str, object]], TimeWindow]: The config copy without
        `tmin`/`tmax` keys and the updated time window copy.
    """
    # (config key, warning template) - tmin is handled before tmax for each source.
    frozen_bound_messages = (
        (
            "tmin",
            "For this variable tmin is already frozen to {tmin}. Any other tmin will be ignored.",
        ),
        (
            "tmax",
            "For this variable tmax is already frozen to {tmax}. Any other tmax will be ignored.",
        ),
    )

    window_copy = copy.copy(time_window)
    configs_copy = copy.deepcopy(variable_configs_by_source)

    for source_config in configs_copy.values():
        for bound, template in frozen_bound_messages:
            if bound not in source_config:
                continue
            frozen_value = source_config.pop(bound)
            logger.warning(__(template, **{bound: frozen_value}))
            setattr(window_copy, bound, TimeAnchor(frozen_value))

    return configs_copy, window_copy
def _ensure_compatibility(
    var_name: str,
    variable_configs_by_source: dict[str, dict[str, Any]],
    obs_level: ObsLevel,
) -> None:
    """Verify that every source config allows the given observation level.

    Each source config that has a `"compatible_with"` entry gets that entry
    removed (the dicts are modified in place). The check fails as soon as one
    of those entries does not contain `obs_level.lower_name`. Configs without
    the key are not checked.

    Args:
        var_name (str): Variable name, used only in the error message.
        variable_configs_by_source (dict[str, dict[str, Any]]): Per-source config dicts.
        obs_level (ObsLevel): Observation level to check against.

    Raises:
        ObsLevelNotSupportedError: If a source lists compatible levels and
            `obs_level.lower_name` is not among them.
    """
    for source_config in variable_configs_by_source.values():
        if "compatible_with" not in source_config:
            continue

        # Popped so the key is not forwarded with the remaining config later on.
        supported_levels: list[ObsLevelName] = source_config.pop("compatible_with")
        if obs_level.lower_name in supported_levels:
            continue

        raise ObsLevelNotSupportedError(
            f"Obs level {obs_level.lower_name} is not compatible with {var_name}"
        )
def _log_variable_loader_args_kwargs(
    var_name: str,
    time_window: TimeWindow,
    var_config: dict[str, object],
) -> None:
    """Write the loader arguments for `var_name` to the log as a small tree.

    The time window is always logged. The config is logged as a final
    `kwargs` branch only when it is non-empty.

    Args:
        var_name (str): Variable name being created.
        time_window (TimeWindow): Time window to report.
        var_config (dict[str, object]): Keyword arguments to report.
    """
    has_kwargs = bool(var_config)

    logger.info(__("Creating variable {var_name} with:", var_name=var_name))

    # The time window is the last branch of the tree unless kwargs follow it.
    branch = "├──" if has_kwargs else "└──"
    logger.info(
        __(
            "{prefix} time_window: {time_window}",
            prefix=branch,
            time_window=time_window,
        )
    )

    if not has_kwargs:
        return

    logger.info("└── kwargs:")
    utils.log_dict(logger=logger, dictionary=var_config, indent=4, json_indent=2)
def _load_variables_from_config(
    var_name: str,
    variable_configs_by_source: dict[str, dict[str, Any]],
    cohort: Cohort,
    time_window: TimeWindow,
) -> dict[str, VariableProtocol]:
    """Instantiate a source-specific variable for each entry in `variable_configs_by_source`.

    For each source that is part of `cohort.sources`, retrieves the source's
    `VariableLoader`, calls it with `var_name`, `time_window`, and the source's
    config dict as keyword arguments, and stores the result in the output
    mapping. Sources that are configured but not part of the cohort are logged
    and skipped.

    Args:
        var_name (str): Name of the variable to instantiate.
        variable_configs_by_source (dict[str, dict[str, Any]]): Per-source config dicts
            whose keys determine which sources are visited.
        cohort (Cohort): The cohort (used to skip sources not in `cohort.sources`).
        time_window (TimeWindow): Time window forwarded to every `VariableLoader`.

    Returns:
        dict[str, VariableProtocol]: Mapping from source name to the instantiated
        variable. Empty if no configured source belongs to the cohort.
    """
    src_var_mapping: dict[str, VariableProtocol] = {}

    for src, src_module in corr.get_src_module_mapping(
        include_sources=variable_configs_by_source.keys()
    ).items():
        if src not in cohort.sources:
            logger.info(
                __("Source {src} not requested in cohort. Skipping...", src=src)
            )
            # Actually skip: without this the source was announced as skipped
            # but its variable was still instantiated and returned.
            continue

        # Load kwargs
        var_config = variable_configs_by_source[src]
        _log_variable_loader_args_kwargs(
            var_name=var_name, time_window=time_window, var_config=var_config
        )

        # Get VariableLoaderProtocol (adapter) to load VariableProtocol
        SrcVariableLoader: VariableLoaderProtocol = getattr(
            src_module, "VariableLoader"
        )
        SrcVariable: VariableProtocol = SrcVariableLoader.__call__(
            var_name=var_name, time_window=time_window, **var_config
        )
        src_var_mapping.setdefault(src, SrcVariable)

        logger.info(
            __(
                "Variable {var_name} found in source {src}",
                var_name=var_name,
                src=src,
            )
        )

    return src_var_mapping
[docs]
def load_variable(
    var_name: str,
    cohort: Cohort,
    time_window: TimeWindow,
    include_sources: Iterable[str] | None = None,
) -> MultiSourceVariable:
    """Build a :class:`MultiSourceVariable` for `var_name`.

    Steps, in order: load the per-source configuration (including project
    overrides), move any `tmin`/`tmax` entries from the configuration onto the
    time window, check compatibility with the cohort's observation level, and
    instantiate one variable per source.

    Args:
        var_name (str): Name of the variable to load.
        cohort (Cohort): Cohort providing sources, overrides and observation level.
        time_window (TimeWindow): Requested time window; may be overridden by
            `tmin`/`tmax` entries in the variable configuration.
        include_sources (Iterable[str] | None): Sources to look in. When this is
            `None` or otherwise falsy, the keys of `cohort.sources` are used.

    Returns:
        MultiSourceVariable: Variable ready for extraction.

    Raises:
        VariableNotFoundError: If no configuration exists for `var_name`, or if
            no source variable could be instantiated.
    """
    # Decide where to look
    if include_sources:
        lookup_sources = include_sources
    else:
        lookup_sources = cohort.sources.keys()

    # Per-source configuration, project overrides included
    configs = _load_variable_configs(
        var_name=var_name, cohort=cohort, include_sources=lookup_sources
    )
    if not configs:
        raise VariableNotFoundError(f"Variable {var_name} not found in JSON")

    # Debug logging
    logger.debug(__("Loaded config for {var_name}:", var_name=var_name))
    utils.log_dict(logger, configs, level=logging.DEBUG, json_indent=2)

    # "tmin"/"tmax" found in the configuration win over the requested window
    configs, effective_window = _transfer_time_window_from_variable_config(
        configs, time_window
    )

    # Abort early when the observation level is not supported
    _ensure_compatibility(var_name, configs, cohort.obs_level)

    # One variable instance per source
    instantiated = _load_variables_from_config(
        var_name=var_name,
        variable_configs_by_source=configs,
        cohort=cohort,
        time_window=effective_window,
    )
    if not instantiated:
        raise VariableNotFoundError(
            f"Variable {var_name} not found in any source (looked in {lookup_sources})"
        )

    return MultiSourceVariable(instantiated)
# ---------------------------------------------------------------------------
# load_default_variables -> list[MultiSourceVariable]
# ---------------------------------------------------------------------------
def _collect_default_variable_list_by_source(
    obs_level: ObsLevel,
    include_sources: Iterable[str] | None = None,
) -> dict[str, list[str]]:
    """Read the default variable names declared by each source.

    For every source whose `VARS` mapping has `corr_defaults.default_vars`, the
    `"global"` list is concatenated with the list stored under
    `obs_level.lower_name` (each defaulting to an empty list when absent).
    Sources without that section are left out of the result.

    Args:
        obs_level (ObsLevel): Observation level selecting the level-specific list.
        include_sources (Iterable[str] | None): Sources to look at. Forwarded to
            `corr.get_src_module_mapping`; `None` applies no restriction here.

    Returns:
        dict[str, list[str]]: Source name mapped to its default variable names,
        global entries first.
    """
    defaults_by_source: dict[str, list[str]] = {}

    mapping_modules = corr.get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )
    for source, mapping_module in mapping_modules.items():
        source_vars = mapping_module.VARS

        # Sources are free to declare no defaults at all
        if (
            "corr_defaults" not in source_vars
            or "default_vars" not in source_vars["corr_defaults"]
        ):
            continue

        declared: dict[str, list[str]] = source_vars["corr_defaults"]["default_vars"]
        defaults_by_source[source] = declared.get("global", []) + declared.get(
            obs_level.lower_name, []
        )

    return defaults_by_source
def load_default_variables(
    cohort: Cohort, time_window: TimeWindow
) -> list[MultiSourceVariable]:
    """Load every default variable declared by the cohort's sources.

    The per-source default lists (for the cohort's observation level) are
    inverted into a variable-to-sources mapping, and each variable is then
    loaded with `load_variable`, restricted to the sources that listed it.

    Args:
        cohort (Cohort): Cohort providing the observation level and sources.
        time_window (TimeWindow): Time window passed to every `load_variable` call.

    Returns:
        list[MultiSourceVariable]: One entry per distinct default variable name,
        in order of first appearance.
    """
    # Default variable names per source
    defaults_by_source = _collect_default_variable_list_by_source(
        obs_level=cohort.obs_level, include_sources=cohort.sources.keys()
    )

    # Invert: variable name -> sources declaring it (one entry per declaration)
    sources_by_variable: dict[str, list[str]] = {}
    for source, default_names in defaults_by_source.items():
        for name in default_names:
            sources_by_variable.setdefault(name, []).append(source)

    # TODO: Loading only included sources can be tricky,
    # if the variable is loaded in other sources later
    return [
        load_variable(
            var_name=name,
            cohort=cohort,
            time_window=time_window,
            include_sources=declaring_sources,
        )
        for name, declaring_sources in sources_by_variable.items()
    ]