# Source code for corr_vars.sources.var_loader

from __future__ import annotations
import copy
import logging
import warnings

import polars as pl
import polars.selectors as cs

from corr_vars import logger, __
import corr_vars.sources as corr
import corr_vars.utils as utils

from collections.abc import Collection, Iterable, Mapping
from corr_vars.definitions.exceptions import (
    ObsLevelNotSupportedError,
    VariableNotFoundError,
)
from corr_vars.definitions.constants import DYN_COLUMNS, COL_ORDER, PRIMARY_KEYS
from corr_vars.utils.time import TimeAnchor
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from corr_vars.core.cohort import Cohort, ObsLevel
    from corr_vars.definitions.typing import (
        VariableLoaderProtocol,
        VariableProtocol,
        ObsLevelName,
    )
    from corr_vars.utils.time import TimeWindow

# ---------------------------------------------------------------------------
# MultiSourceVariable with pre and post-processors
# ---------------------------------------------------------------------------


def validate_primary_key(var_df: pl.DataFrame, primary_key: str) -> None:
    """Check that `primary_key` is a recognised key and a column of `var_df`.

    Args:
        var_df (pl.DataFrame): Variable data whose columns are inspected.
        primary_key (str): Name of the primary key column.

    Raises:
        ValueError: If `primary_key` is not listed in `PRIMARY_KEYS`, or if
            `var_df` has no column with that name.
    """
    is_known_key = primary_key in PRIMARY_KEYS
    if not is_known_key:
        raise ValueError(f"Primary key {primary_key} must be one of {PRIMARY_KEYS}.")

    is_present = primary_key in var_df.columns
    if not is_present:
        raise ValueError(f"Primary key {primary_key} not in data columns.")


def add_relative_times(
    reference_df: pl.DataFrame,
    var_df: pl.DataFrame,
    *,
    reference_col: str,
    join_col: str,
    suffix: str = "_relative",
    include_time_cols: Collection[str] | None = None,
) -> pl.DataFrame:
    """Add relative time columns to dynamic variable data.

    Left-joins `reference_col` from `reference_df` onto `var_df` via `join_col`.
    Then, for each temporal column of `var_df` (optionally restricted to
    `include_time_cols`), it computes the difference to `reference_col` with
    `utils.time_difference` (``unit="s"``, ``total=True``). The result is stored
    under the original column name with `suffix` appended. The joined reference
    column is dropped from the result.

    The joined reference column itself is never made relative. This means no
    spurious ``<reference_col><suffix>`` column appears when `include_time_cols`
    is `None`.

    Args:
        reference_df (pl.DataFrame): DataFrame containing the reference time column.
        var_df (pl.DataFrame): DataFrame containing the variable data.
        reference_col (str): Column in `reference_df` holding the reference timestamp.
        join_col (str): Column used to join `reference_df` onto `var_df`.
        suffix (str): Suffix appended to each relative-time column name. Defaults to
            `"_relative"`.
        include_time_cols (Collection[str] | None): Restrict relative-time computation
            to these temporal columns. All temporal columns are used when `None`.

    Returns:
        pl.DataFrame: `var_df` with relative time columns added and `reference_col` dropped.
    """
    var_df = var_df.join(
        reference_df.select(join_col, reference_col),
        on=join_col,
        how="left",
    )

    # Only the variable's own temporal columns are converted. The joined
    # reference column is excluded: its offset to itself is always zero, and
    # it is dropped below, so keeping it would leave an orphaned
    # `<reference_col><suffix>` column behind.
    selector = cs.temporal() - cs.by_name(reference_col)
    if include_time_cols is not None:
        selector &= cs.by_name(include_time_cols, require_all=False)

    var_df = var_df.with_columns(
        selector.pipe(
            utils.time_difference,
            reference_col=reference_col,
            unit="s",
            total=True,
        ).name.suffix(suffix)
    )

    # Drop the reference time column
    return var_df.drop(reference_col)


def unify_and_order_columns(var_df: pl.DataFrame) -> pl.DataFrame:
    """Enforce `COL_ORDER` and fill missing dynamic columns with null.

    Every column of `DYN_COLUMNS` that is absent from `var_df` is added as an
    all-null column. Columns listed in `COL_ORDER` are then placed first, in
    `COL_ORDER` order.

    Any other column triggers a deprecation `UserWarning`. Such columns are
    appended afterwards, keeping the relative order they had in `var_df`,
    until the `attributes` migration is complete.

    Args:
        var_df (pl.DataFrame): The variable DataFrame to standardize.

    Returns:
        pl.DataFrame: `var_df` with columns reordered and missing dynamic columns added as null.
    """
    # Add missing dynamic columns (no keys) with null values
    missing_cols = [col for col in DYN_COLUMNS if col not in var_df.columns]
    if missing_cols:
        var_df = var_df.with_columns(pl.lit(None).alias(col) for col in missing_cols)

    # Order columns according to COL_ORDER, then add remaining columns.
    # Iterate the column list (not a set) so the position of unknown columns
    # is deterministic and does not depend on string hash randomization.
    data_cols = var_df.columns
    data_col_set = set(data_cols)
    known_cols = set(COL_ORDER)
    ordered_cols = [col for col in COL_ORDER if col in data_col_set]
    remaining_cols = [col for col in data_cols if col not in known_cols]
    if remaining_cols:
        warnings.warn(
            "🚨 DEPRECATION ALERT 🚨\n"
            "Unknown columns found.\n"
            "- Please update the variable definition to move these to attributes.\n"
            "- Unknown columns are no longer supported and might be deprecated soon!\n"
            "- Columns: " + ", ".join(remaining_cols),
            UserWarning,
        )

    var_df = var_df.select(ordered_cols + remaining_cols)

    # TODO: once the deprecation period is over, keep only the known columns:
    # var_df = var_df.select(ordered_cols)
    return var_df


class MultiSourceVariable:
    """Combine one variable implementation per source into a single variable.

    Holds one :class:`~corr_vars.definitions.typing.VariableProtocol` per source and
    extracts, post-validates, and post-processes each before concatenating the results.
    All source variables must agree on `var_name`, `dynamic` and `time_window`.
    These shared values are taken from the first variable and exposed as
    read-only properties.

    Args:
        variables (Mapping[str, VariableProtocol]): Mapping from source name to the
            source-specific variable implementation. Must not be empty.

    Raises:
        ValueError: If `variables` is empty.
        AssertionError: If the source variables disagree on `var_name`, `dynamic`
            or `time_window`.
    """

    variables: Mapping[str, VariableProtocol]

    # Concatenated result of the last `extract` call; None until extracted.
    data: pl.DataFrame | None

    _var_name: str
    _dynamic: bool
    _time_window: TimeWindow

    def __init__(self, variables: Mapping[str, VariableProtocol]) -> None:
        self.variables = variables
        self.data = None
        self._pre_validate()

    def _pre_validate(self) -> None:
        """Check that the source variables are non-empty and mutually consistent.

        Raises:
            ValueError: If there are no source variables.
            AssertionError: If any source variable's `var_name`, `dynamic` or
                `time_window` differs from the first one.
        """
        if not self.variables:
            raise ValueError("No variables to concatenate")

        # Take var_name, dynamic and time_window from the first variable, now
        # that the mapping is known to be non-empty.
        first_variable = next(iter(self.variables.values()))
        self._var_name = first_variable.var_name
        self._dynamic = first_variable.dynamic
        self._time_window = first_variable.time_window

        # Explicit raises instead of `assert` statements, so the checks still
        # run under `python -O`. AssertionError is kept so existing handlers
        # continue to work.
        for var in self.variables.values():
            if var.var_name != self._var_name:
                raise AssertionError(
                    f"Variable name mismatch: {var.var_name} != {self._var_name}"
                )
            if var.dynamic != self._dynamic:
                raise AssertionError(
                    f"Dynamic status mismatch for {self._var_name}: {var.dynamic} != {self._dynamic}"
                )
            if var.time_window != self._time_window:
                raise AssertionError(
                    f"Timebound mismatch for {self._var_name}: {var.time_window} != {self._time_window}"
                )

    def extract(self, cohort: Cohort, with_source_column: bool = False) -> pl.DataFrame:
        """Extract every source variable and concatenate the results.

        Each source's data is validated (`_post_validate`) and then
        post-processed (`_post_process`) before concatenation with
        ``how="diagonal_relaxed"``.

        Args:
            cohort (Cohort): Cohort the variables are extracted for.
            with_source_column (bool): If True, prepend a `data_source` column
                holding the source name. Defaults to False.

        Returns:
            pl.DataFrame: The concatenated data, also stored in `self.data`.
        """
        var_data: list[pl.DataFrame] = []

        for src, var in self.variables.items():
            # Extract and validate data
            extracted = var.extract(cohort)
            self._post_validate(cohort, extracted)

            # Transform Data
            extracted = self._post_process(cohort, extracted)
            if with_source_column:
                extracted = extracted.insert_column(0, pl.lit(src).alias("data_source"))

            # Add to list
            var_data.append(extracted)
            logger.info(
                __(
                    "SUCCESS: Extracted {var_name} from {src}",
                    var_name=self._var_name,
                    src=src,
                )
            )

        # Concatenate data
        combined_data = pl.concat(var_data, how="diagonal_relaxed")
        self.data = combined_data
        return combined_data

    def _post_process(self, cohort: Cohort, var_data: pl.DataFrame) -> pl.DataFrame:
        """Add relative record times and standardize columns for dynamic data.

        Static data is returned unchanged. For dynamic data, the
        `recordtime_relative` / `recordtime_end_relative` columns are computed
        against `cohort.t_min` unless both already exist. Columns are then
        reordered via `unify_and_order_columns`.
        """
        if not self.dynamic:
            return var_data

        processed = var_data
        include_time_cols = ("recordtime", "recordtime_end")
        suffix = "_relative"
        if not frozenset(f"{col}{suffix}" for col in include_time_cols).issubset(
            frozenset(var_data.columns)
        ):
            processed = add_relative_times(
                cohort.obs,
                processed,
                reference_col=cohort.t_min,
                join_col=cohort.primary_key,
                suffix=suffix,
                include_time_cols=include_time_cols,
            )

        processed = unify_and_order_columns(processed)
        return processed

    def _post_validate(self, cohort: Cohort, var_data: pl.DataFrame) -> None:
        """Ensure `var_data` carries the cohort's (known) primary key column."""
        validate_primary_key(var_data, cohort.primary_key)

    @property
    def var_name(self) -> str:
        return self._var_name

    @var_name.setter
    def var_name(self, _: str) -> None:
        raise AttributeError("Cannot set var_name")

    @property
    def dynamic(self) -> bool:
        return self._dynamic

    @dynamic.setter
    def dynamic(self, _: bool) -> None:
        raise AttributeError("Cannot set dynamic")

    @property
    def time_window(self) -> TimeWindow:
        return self._time_window

    @time_window.setter
    def time_window(self, _: TimeWindow) -> None:
        # NOTE: raises ValueError, unlike the other read-only properties (which
        # raise AttributeError). It is kept as-is for callers that may catch it.
        raise ValueError("Cannot set time_window")

    @property
    def extracted(self) -> bool:
        """True once `extract` has stored data in `self.data`."""
        return self.data is not None

    def __repr__(self) -> str:
        return (
            f"MultiSourceVariable(var_name='{self._var_name}', "
            f"dynamic={self._dynamic}, "
            f"time_window={self._time_window}, "
            f"sources={list(self.variables.keys())}, "
            f"extracted={self.extracted})"
        )

    def __str__(self) -> str:
        return self.__repr__()


# ---------------------------------------------------------------------------
# load_raw_variable_configs -> dict[str, dict[str, dict[str, object]]]
# ---------------------------------------------------------------------------


def load_raw_variable_configs(
    include_sources: Iterable[str] | None = None,
) -> dict[str, dict[str, dict[str, object]]]:
    """Return the raw `vars.json` variable table of every source.

    Args:
        include_sources (Iterable[str] | None): Only these sources are read.
            Every discovered source is read when `None`.

    Returns:
        dict[str, dict[str, dict[str, object]]]: Mapping from source name to that
        source's ``VARS["variables"]`` dict. The dict is returned as-is, not copied.
    """
    mapping_modules = corr.get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )
    return {
        source_name: getattr(module, "VARS")["variables"]
        for source_name, module in mapping_modules.items()
    }


# ---------------------------------------------------------------------------
# _load_variable_configs -> dict[str, dict[str, object]]
# ---------------------------------------------------------------------------


def _collect_variable_configs_by_source(
    var_name: str,
    include_sources: Iterable[str] | None = None,
) -> dict[str, dict[str, object]]:
    """Gather the `vars.json` entry for `var_name` from each source that defines it.

    When the source's `mapping.variables` module has an attribute named
    `var_name`, that object is stored in the returned config under the
    `"py"` key.

    Args:
        var_name (str): Name of the variable to look up.
        include_sources (Iterable[str] | None): Only these sources are searched.
            Every discovered source is searched when `None`.

    Returns:
        dict[str, dict[str, object]]: Mapping from source name to a shallow copy of
        its config for `var_name` (with `"py"` injected where applicable). Sources
        that do not define `var_name` are omitted.
    """
    configs: dict[str, dict[str, Any]] = {}
    mapping_modules = corr.get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )

    for source_name, module in mapping_modules.items():
        defined_variables = getattr(module, "VARS")["variables"]
        variable_code = getattr(module, "variables")

        if var_name in defined_variables:
            # Shallow copy so injecting "py" leaves the module-level VARS untouched.
            config = defined_variables[var_name].copy()
            if hasattr(variable_code, var_name):
                config["py"] = getattr(variable_code, var_name)
            configs[source_name] = config

    return configs


def _load_variable_configs(
    var_name: str,
    cohort: Cohort,
    include_sources: Iterable[str] | None = None,
) -> dict[str, dict[str, object]]:
    """Build the per-source configuration of `var_name`, including project overrides.

    The source defaults come from :func:`_collect_variable_configs_by_source`.
    The project-local definition returned by
    :meth:`~corr_vars.core.cohort.Cohort.get_variable_definition` is then
    deep-merged on top of them.

    Args:
        var_name (str): Name of the variable to load configuration for.
        cohort (Cohort): Cohort whose project overrides are applied.
        include_sources (Iterable[str] | None): Sources to consider.
            Every source is used when `None`.

    Returns:
        dict[str, dict[str, object]]: Merged per-source variable configuration mapping.
    """
    source_defaults = _collect_variable_configs_by_source(
        var_name=var_name,
        include_sources=include_sources,
    )

    # Project-specific definitions take precedence over the source defaults.
    project_override = cohort.get_variable_definition(var_name)
    return utils.deep_merge(base=source_defaults, override=project_override)


# ---------------------------------------------------------------------------
# load_variable -> MultiSourceVariable
# ---------------------------------------------------------------------------


def _transfer_time_window_from_variable_config(
    variable_configs_by_source: dict[str, dict[str, Any]],
    time_window: TimeWindow,
) -> tuple[dict[str, dict[str, object]], TimeWindow]:
    """Extract `tmin`/`tmax` overrides from source configs into the time window.

    If any source config contains a `"tmin"` or `"tmax"` key, those values are
    popped from the config and written to a copy of `time_window`, overriding whatever
    was passed in. A warning is logged for each override.

    Args:
        variable_configs_by_source (dict[str, dict[str, Any]]): Per-source config dicts.
        time_window (TimeWindow): The original time window.

    Returns:
        tuple[dict[str, dict[str, object]], TimeWindow]: Deep copies of the config dict
        and time window, updated with any embedded overrides.
    """
    time_window_cp = copy.copy(time_window)
    variable_configs_by_source_cp = copy.deepcopy(variable_configs_by_source)
    for _, var_config in variable_configs_by_source_cp.items():
        if "tmin" in var_config:
            logger.warning(
                __(
                    "For this variable tmin is already frozen to {tmin}. Any other tmin will be ignored.",
                    tmin=var_config["tmin"],
                )
            )
            tmin = var_config.pop("tmin")
            time_window_cp.tmin = TimeAnchor(tmin)

        if "tmax" in var_config:
            logger.warning(
                __(
                    "For this variable tmax is already frozen to {tmax}. Any other tmax will be ignored.",
                    tmax=var_config["tmax"],
                )
            )
            tmax = var_config.pop("tmax")
            time_window_cp.tmax = TimeAnchor(tmax)

    return variable_configs_by_source_cp, time_window_cp


def _ensure_compatibility(
    var_name: str,
    variable_configs_by_source: dict[str, dict[str, Any]],
    obs_level: ObsLevel,
) -> None:
    """Raise if `var_name` is incompatible with the cohort's observation level.

    Pops `"compatible_with"` from each source config dict (mutates in-place) and
    raises :exc:`~corr_vars.definitions.exceptions.ObsLevelNotSupportedError` if the
    cohort's obs level is not listed.

    Args:
        var_name (str): Variable name (used in the error message).
        variable_configs_by_source (dict[str, dict[str, Any]]): Per-source config dicts.
        obs_level (ObsLevel): The cohort's current observation level.

    Raises:
        ObsLevelNotSupportedError: If `obs_level` is not in `compatible_with`.
    """
    for _, var_config in variable_configs_by_source.items():
        if "compatible_with" in var_config:
            compatible_obs_level: list[ObsLevelName] = var_config.pop("compatible_with")
            if obs_level.lower_name not in compatible_obs_level:
                raise ObsLevelNotSupportedError(
                    f"Obs level {obs_level.lower_name} is not compatible with {var_name}"
                )


def _log_variable_loader_args_kwargs(
    var_name: str,
    time_window: TimeWindow,
    var_config: dict[str, object],
) -> None:
    """Log, as a small tree, what will be passed to a source's `VariableLoader`.

    Args:
        var_name (str): Variable name being created.
        time_window (TimeWindow): Time window used for extraction.
        var_config (dict[str, object]): Keyword arguments forwarded to the loader.
    """
    logger.info(__("Creating variable {var_name} with:", var_name=var_name))

    has_kwargs = bool(var_config)
    # The time_window line is the last tree branch only when no kwargs follow.
    logger.info(
        __(
            "{prefix} time_window: {time_window}",
            prefix="├──" if has_kwargs else "└──",
            time_window=time_window,
        )
    )
    if not has_kwargs:
        return

    logger.info("└── kwargs:")
    utils.log_dict(logger=logger, dictionary=var_config, indent=4, json_indent=2)


def _load_variables_from_config(
    var_name: str,
    variable_configs_by_source: dict[str, dict[str, Any]],
    cohort: Cohort,
    time_window: TimeWindow,
) -> dict[str, VariableProtocol]:
    """Instantiate a source-specific variable for each entry in `variable_configs_by_source`.

    For each configured source that is also part of `cohort.sources`, this
    retrieves the source's `VariableLoader` and calls it with `var_name`,
    `time_window` and the source's config dict as keyword arguments. The
    resulting :class:`~corr_vars.definitions.typing.VariableProtocol` is
    stored in the output mapping. Sources not requested by the cohort are
    logged and skipped.

    Args:
        var_name (str): Name of the variable to instantiate.
        variable_configs_by_source (dict[str, dict[str, Any]]): Per-source config dicts
            whose keys determine which sources are visited.
        cohort (Cohort): The cohort (used to skip sources not in `cohort.sources`).
        time_window (TimeWindow): Time window forwarded to every `VariableLoader`.

    Returns:
        dict[str, VariableProtocol]: Mapping from source name to the instantiated
        variable. Empty if none of the configured sources is part of the cohort.
    """

    src_var_mapping: dict[str, VariableProtocol] = {}

    # Visit every source that has a config for this variable
    for src, src_module in corr.get_src_module_mapping(
        include_sources=variable_configs_by_source.keys()
    ).items():
        if src not in cohort.sources:
            logger.info(
                __("Source {src} not requested in cohort. Skipping...", src=src)
            )
            # Really skip it: the log message above promises this source is not loaded.
            continue

        # Load kwargs
        var_config = variable_configs_by_source[src]
        _log_variable_loader_args_kwargs(
            var_name=var_name, time_window=time_window, var_config=var_config
        )

        # Get VariableLoaderProtocol (adapter) to load VariableProtocol.
        # NOTE(review): `X.__call__(...)` and `X(...)` differ if VariableLoader is a
        # class that defines `__call__` itself; kept explicit until confirmed.
        SrcVariableLoader: VariableLoaderProtocol = getattr(
            src_module, "VariableLoader"
        )
        SrcVariable: VariableProtocol = SrcVariableLoader.__call__(
            var_name=var_name, time_window=time_window, **var_config
        )
        src_var_mapping[src] = SrcVariable

        logger.info(
            __(
                "Variable {var_name} found in source {src}",
                var_name=var_name,
                src=src,
            )
        )

    return src_var_mapping


def load_variable(
    var_name: str,
    cohort: Cohort,
    time_window: TimeWindow,
    include_sources: Iterable[str] | None = None,
) -> MultiSourceVariable:
    """Load `var_name` from all configured sources and wrap it in a :class:`MultiSourceVariable`.

    Orchestrates config loading, project-override merging, time-window transfer,
    compatibility checks, and per-source variable instantiation.

    Args:
        var_name (str): Name of the variable to load.
        cohort (Cohort): Cohort that supplies source configs and project overrides.
        time_window (TimeWindow): Time window for variable extraction. Any
            `tmin`/`tmax` embedded in the variable config overrides it.
        include_sources (Iterable[str] | None): Restrict to these sources. All cohort
            sources are used when `None` (or when an empty collection is given).

    Returns:
        MultiSourceVariable: Variable ready for extraction.

    Raises:
        VariableNotFoundError: If `var_name` is not found in any source config, or if
            no source variable could be instantiated.
        ObsLevelNotSupportedError: If a source config's `compatible_with` list does
            not include the cohort's observation level.
    """
    # Materialise the lookup sources once. `include_sources` may be a one-shot
    # iterator, and a plain list renders readably in the error message below
    # (instead of e.g. `dict_keys([...])` or a generator repr).
    sources = list(include_sources or cohort.sources.keys())

    # Get variable configuration for each source defined in cohort
    variable_configs_by_source = _load_variable_configs(
        var_name=var_name, cohort=cohort, include_sources=sources
    )

    # Raise error if no source was found
    if not variable_configs_by_source:
        raise VariableNotFoundError(f"Variable {var_name} not found in JSON")

    # Debug logging
    logger.debug(__("Loaded config for {var_name}:", var_name=var_name))
    utils.log_dict(
        logger, variable_configs_by_source, level=logging.DEBUG, json_indent=2
    )

    # Transfer "tmin"/"tmax" from the config to time_window. They are popped from
    # a copy of the config and override the time_window passed in.
    variable_configs_by_source, time_window = (
        _transfer_time_window_from_variable_config(
            variable_configs_by_source, time_window
        )
    )

    # STOP if not compatible with obs level (also strips "compatible_with")
    _ensure_compatibility(var_name, variable_configs_by_source, cohort.obs_level)

    # Get variable for each variable configuration by source
    variable_by_source = _load_variables_from_config(
        var_name=var_name,
        variable_configs_by_source=variable_configs_by_source,
        cohort=cohort,
        time_window=time_window,
    )

    # Raise error if no variable was found
    if not variable_by_source:
        raise VariableNotFoundError(
            f"Variable {var_name} not found in any source (looked in {sources})"
        )

    return MultiSourceVariable(variable_by_source)
# ---------------------------------------------------------------------------
# load_default_variables -> list[MultiSourceVariable]
# ---------------------------------------------------------------------------


def _collect_default_variable_list_by_source(
    obs_level: ObsLevel,
    include_sources: Iterable[str] | None = None,
) -> dict[str, list[str]]:
    """Collect the default variable list from each source's `vars.json`.

    Reads `corr_defaults.default_vars` from each source's `VARS` dict and
    combines the `"global"` list with the obs-level-specific list. Names are
    de-duplicated, keeping the first occurrence. A variable listed in both
    lists therefore appears only once. Sources without
    `corr_defaults.default_vars` are skipped.

    Args:
        obs_level (ObsLevel): Observation level used to select level-specific defaults.
        include_sources (Iterable[str] | None): Restrict to these sources.
            All discovered sources are checked when `None`.

    Returns:
        dict[str, list[str]]: Mapping from source name to its list of default variable names.
    """
    src_dict: dict[str, list[str]] = {}

    for src, src_mapping_module in corr.get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    ).items():
        src_VARS = getattr(src_mapping_module, "VARS")
        if not (
            "corr_defaults" in src_VARS
            and "default_vars" in src_VARS["corr_defaults"]
        ):
            continue

        default_vars: dict[str, list[str]] = src_VARS["corr_defaults"]["default_vars"]
        global_defaults = default_vars.get("global", [])
        obs_level_defaults = default_vars.get(obs_level.lower_name, [])
        # dict.fromkeys drops duplicates while keeping first-seen order, so a
        # variable listed under both "global" and the obs level is loaded once
        # and this source is not recorded twice for it.
        all_defaults = list(dict.fromkeys([*global_defaults, *obs_level_defaults]))
        src_dict.setdefault(src, all_defaults)

    return src_dict


def load_default_variables(
    cohort: Cohort, time_window: TimeWindow
) -> list[MultiSourceVariable]:
    """Load the default variables for the cohort's observation level.

    Aggregates the default variable lists from all cohort sources, then loads each
    variable restricted to the sources that declared it.

    Args:
        cohort (Cohort): Cohort whose observation level and sources determine which
            defaults are loaded.
        time_window (TimeWindow): Time window forwarded to every variable loader.

    Returns:
        list[MultiSourceVariable]: One :class:`MultiSourceVariable` per unique default
        variable name, ready for extraction.
    """
    # Get default variables per source
    source_default_variables_mapping = _collect_default_variable_list_by_source(
        obs_level=cohort.obs_level, include_sources=cohort.sources.keys()
    )

    # Invert mapping: variable name -> sources that list it as a default
    variable_source_mapping: dict[str, list[str]] = {}
    for src, default_variables in source_default_variables_mapping.items():
        for var_name in default_variables:
            variable_source_mapping.setdefault(var_name, []).append(src)

    # Load variable for included sources only.
    # TODO: Loading only included sources can be tricky,
    # if the variable is loaded in other sources later
    return [
        load_variable(
            var_name=var_name,
            cohort=cohort,
            time_window=time_window,
            include_sources=include_sources,
        )
        for var_name, include_sources in variable_source_mapping.items()
    ]