# Source code for corr_vars.sources.var_loader

from __future__ import annotations

import importlib
import os
import pkgutil

import polars as pl

from corr_vars import logger
import corr_vars.sources as sources
import corr_vars.utils as utils

from collections.abc import Iterable, Mapping
from types import ModuleType
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from corr_vars.core.cohort import Cohort
    from corr_vars.core.variable import Variable, TimeWindow, TimeAnchor
    from corr_vars.definitions import VariableProtocol

# Only get submodules that have a __init__.py file -> real sources
SOURCES = [
    name
    for _, name, ispkg in pkgutil.iter_modules(sources.__path__)
    if ispkg
    and os.path.isdir(os.path.join(os.path.dirname(sources.__file__), name))
    and os.path.isfile(
        os.path.join(os.path.dirname(sources.__file__), name, "__init__.py")
    )
]


class MultiSourceVariable:
    """Concatenate variables from multiple sources.

    All wrapped variables must share the same name, dynamic status and time
    window; this is verified once at construction time.

    Args:
        var_name (str): The name of the variable to concatenate.
        variables (Mapping[str, VariableProtocol]): The per-source variables to
            concatenate, keyed by source name.

    Raises:
        ValueError: If ``variables`` is empty.
        AssertionError: If the variables disagree with ``var_name`` or with each
            other on dynamic status or time window.
    """

    var_name: str
    variables: Mapping[str, VariableProtocol]

    # Concatenated result of the last extract() call; None until then.
    data: pl.DataFrame | None

    _dynamic: bool
    _time_window: TimeWindow

    def __init__(
        self, var_name: str, variables: Mapping[str, VariableProtocol]
    ) -> None:
        self.var_name = var_name
        self.variables = variables
        # Initialise eagerly: `extracted`, `__repr__` and the `time_window`
        # setter all read `self.data` and must work before extract() is called.
        self.data = None
        self._sanity_check()

    def _sanity_check(self) -> None:
        """Validate the wrapped variables and cache their shared properties."""
        if not self.variables:
            raise ValueError("No variables to concatenate")

        # The first variable serves as the reference that all others must match.
        first_variable = next(iter(self.variables.values()))
        self._dynamic = first_variable.dynamic
        self._time_window = first_variable.time_window

        for var in self.variables.values():
            assert (
                var.var_name == self.var_name
            ), f"Variable name mismatch: {var.var_name} != {self.var_name}"
            assert (
                var.dynamic == self._dynamic
            ), f"Dynamic status mismatch for {self.var_name}: {var.dynamic} != {self._dynamic}"
            assert (
                var.time_window == self._time_window
            ), f"Timebound mismatch for {self.var_name}: {var.time_window} != {self._time_window}"

    def extract(self, cohort: Cohort, with_source_column: bool = False) -> pl.DataFrame:
        """Extract the variable from every source and concatenate the results.

        Args:
            cohort (Cohort): The cohort passed on to each source variable's ``extract``.
            with_source_column (bool): If True, a leading ``data_source`` column
                holding the source name is added to each per-source frame.

        Returns:
            pl.DataFrame: The diagonally concatenated data, also stored in ``self.data``.
        """
        var_data = []

        for src, var in self.variables.items():
            extracted = var.extract(cohort)
            if with_source_column:
                # NOTE(review): per the polars docs insert_column works in place,
                # so the frame returned by var.extract() is modified as well.
                extracted = extracted.insert_column(0, pl.lit(src).alias("data_source"))
            var_data.append(extracted)
            logger.info(f"SUCCESS: Extracted {self.var_name} from {src}")

        combined_data = pl.concat(var_data, how="diagonal_relaxed")
        self.data = combined_data
        return combined_data

    @property
    def dynamic(self) -> bool:
        """Dynamic status shared by all wrapped variables (read-only)."""
        return self._dynamic

    @dynamic.setter
    def dynamic(self, _: bool) -> None:
        raise AttributeError("Cannot set dynamic")

    @property
    def time_window(self) -> TimeWindow:
        """Time window shared by all wrapped variables."""
        return self._time_window

    @time_window.setter
    def time_window(self, time_window: TimeWindow) -> None:
        if self.extracted:
            raise ValueError("Cannot set time_window after extraction")

        # Keep the wrapper and every wrapped variable in sync.
        self._time_window = time_window
        for var in self.variables.values():
            var.time_window = time_window

    @property
    def extracted(self) -> bool:
        """Whether extract() has already produced data."""
        return self.data is not None

    def __repr__(self) -> str:
        return (
            f"MultiSourceVariable(var_name='{self.var_name}', "
            f"dynamic={self._dynamic}, "
            f"time_window={self._time_window}, "
            f"sources={list(self.variables.keys())}, "
            f"extracted={self.extracted})"
        )

    def __str__(self) -> str:
        return self.__repr__()


def _get_src_module(source: str, submodule: str | None = None) -> ModuleType:
    sourced_module = f"corr_vars.sources.{source}"
    return importlib.import_module(
        f"{sourced_module}.{submodule}" if submodule else sourced_module
    )


def _get_src_module_mapping(
    include_sources: Iterable[str] | None = None, submodule: str | None = None
) -> dict[str, ModuleType]:
    if include_sources is None:
        include_sources = SOURCES
    return {
        src: _get_src_module(source=src, submodule=submodule) for src in include_sources
    }


def _get_overrides(var_name: str, cohort: Cohort) -> dict[str, dict[str, Any]]:
    """Get the overrides for a variable.

    Args:
        var_name (str): Name of the variable to get overrides for.
        cohort (Cohort): The cohort to get overrides for.

    Returns:
        overrides (dict[str, dict[str, Any]]): Dictionary of overrides.
    """
    override = {}
    # Project vars
    if hasattr(cohort, "project_vars"):
        # Legacy mode -> Update all existing sources
        if var_name in cohort.project_vars:
            override = {src: cohort.project_vars[var_name] for src in cohort.sources}
        else:
            # New implementation with source specification
            override = {
                src: cohort.project_vars[var_name]
                for src in cohort.sources
                if var_name in cohort.project_vars.get(src, dict())
            }
    return override


def load_default_variables(
    cohort: Cohort, time_window: TimeWindow
) -> list[MultiSourceVariable]:
    """Load the default variables for the observation level of a cohort.

    Every source's ``mapping.VARS`` is inspected for a
    ``corr_defaults -> default_vars`` section. The variables listed under
    ``global`` and under the cohort's observation level are collected and each
    one is loaded from all sources that list it.

    Args:
        cohort (Cohort): The Cohort object to load variables for.
        time_window (TimeWindow): Minimum and maximum time for extraction.

    Returns:
        list[MultiSourceVariable]: List of MultiSourceVariable objects configured for the observation level.
    """
    # variable name -> names of the sources that list it as a default
    # NOTE(review): this walks all known sources, not only cohort.sources -
    # confirm that defaults of sources the cohort does not use are wanted here.
    sources_by_var: dict[str, list[str]] = {}

    for src, src_module in _get_src_module_mapping().items():
        src_vars = getattr(getattr(src_module, "mapping"), "VARS")

        has_defaults = (
            "corr_defaults" in src_vars
            and "default_vars" in src_vars["corr_defaults"]
        )
        if not has_defaults:
            continue

        default_vars = src_vars["corr_defaults"]["default_vars"]
        requested = default_vars.get("global", []) + default_vars.get(
            cohort.obs_level.lower_name, []
        )
        for var_name in requested:
            sources_by_var.setdefault(var_name, []).append(src)

    loaded: list[MultiSourceVariable] = []
    for var_name, src_names in sources_by_var.items():
        src_config = _get_variable_src_config(
            var_name,
            include_sources=src_names,
            override=_get_overrides(var_name, cohort),
        )
        loaded.append(
            _load_variable_from_config(var_name, src_config, cohort, time_window)
        )
    return loaded


def _get_variable_src_config(
    var_name: str,
    include_sources: Iterable[str] | None = None,
    override: dict[str, dict[str, Any]] | None = None,
) -> dict[str, dict[str, object]]:
    """Collect the per-source configuration of a variable.

    Args:
        var_name (str): Name of the variable to get configuration for.
        include_sources (Iterable[str] | None): The sources to consider for the
            config. All sources are considered if set to None.
        override (dict[str, dict[str, Any]] | None): Per-source dictionaries
            merged on top of the matching source configuration.

    Returns:
        dict: Configuration dictionaries for the variable, keyed by source name.

    Raises:
        ValueError: If the variable is not configured in any considered source.
    """
    mapping_modules = _get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )

    configs: dict[str, dict[str, Any]] = {}
    for src, mapping_module in mapping_modules.items():
        declared = getattr(mapping_module, "VARS")["variables"]
        code_module = getattr(mapping_module, "variables")

        if var_name in declared:
            # Shallow copy so the source's own mapping is not modified below.
            config = declared[var_name].copy()
            if hasattr(code_module, var_name):
                # Attach the same-named attribute of the source's variables module.
                config["py"] = getattr(code_module, var_name)
            configs[src] = config

    if not configs:
        raise ValueError(f"Variable {var_name} not found")

    logger.debug(f"Loaded config for {var_name}:")
    utils.log_dict(logger, configs, level=utils.logging.DEBUG, json_indent=2)

    # Project vars: applied after logging, only for sources that have a config.
    if override:
        for src, config in configs.items():
            config.update(override.get(src, dict()))

    return configs


def load_variable(
    var_name: str, cohort: Cohort, time_window: TimeWindow
) -> MultiSourceVariable:
    """Iterates through sources and gets variable from each source.

    Only the sources of the cohort are considered, and the cohort's project
    overrides for the variable are applied to the source configuration.

    Args:
        var_name (str): The name of the variable to load.
        cohort (Cohort): The cohort to load the variable for.
        time_window (TimeWindow): Time window for extraction.

    Returns:
        MultiSourceVariable: The loaded variable ready for extraction.

    Raises:
        ValueError: If the variable is not configured in any source of the cohort.
    """
    override = _get_overrides(var_name, cohort)
    var_src_config = _get_variable_src_config(
        var_name, include_sources=cohort.sources.keys(), override=override
    )
    return _load_variable_from_config(var_name, var_src_config, cohort, time_window)

def _load_variable_from_config(
    var_name: str,
    var_src_config: dict[str, dict[str, Any]],
    cohort: Cohort,
    time_window: TimeWindow,
) -> MultiSourceVariable:
    """Iterates through sources and gets variable from each source.

    Sources present in ``var_src_config`` but not requested by the cohort are
    skipped. A ``tmin``/``tmax`` entry in a source configuration is treated as
    frozen: it is removed from the configuration and replaces the matching
    bound of the time window used for this variable.

    Args:
        var_name (str): The name of the variable to load.
        var_src_config (dict[str, dict[str, Any]]): The source configuration for the variable.
        cohort (Cohort): The cohort to load the variable for.
        time_window (TimeWindow): Time window for extraction. The caller's
            object is left untouched.

    Returns:
        MultiSourceVariable: The loaded variable.

    Raises:
        ValueError: If none of the configured sources is requested by the cohort.
    """
    import copy

    # TimeAnchor is imported at module level only under TYPE_CHECKING, so it
    # has to be imported here to exist at runtime.
    from corr_vars.core.variable import TimeAnchor

    # Work on a private copy: frozen bounds must not leak into the caller's
    # TimeWindow, which may be reused for other variables.
    time_window = copy.copy(time_window)

    src_var_mapping: dict[str, Variable] = {}

    # Load from other sources
    for src, src_module in _get_src_module_mapping(
        include_sources=var_src_config.keys()
    ).items():
        if src not in cohort.sources:
            logger.info(f"Source {src} not requested in cohort. Skipping...")
            continue

        var_config = var_src_config[src]
        SrcVariable = getattr(src_module, "Variable")

        if "tmin" in var_config:
            logger.warning(
                f"For this variable tmin is already frozen to {var_config['tmin']}. Any other tmin will be ignored.",
            )
            time_window.tmin = TimeAnchor(var_config.pop("tmin"))

        if "tmax" in var_config:
            logger.warning(
                f"For this variable tmax is already frozen to {var_config['tmax']}. Any other tmax will be ignored.",
            )
            time_window.tmax = TimeAnchor(var_config.pop("tmax"))

        src_var_mapping[src] = SrcVariable.from_time_window(
            var_name=var_name, time_window=time_window, **var_config
        )
        logger.info(f"Variable {var_name} found in source {src}")

    if not src_var_mapping:
        raise ValueError(
            f"Variable {var_name} not found in any source (looked in {cohort.sources.keys()})"
        )

    return MultiSourceVariable(var_name, src_var_mapping)