Source code for corr_vars.sources.var_loader

from __future__ import annotations

import importlib
import os
import pkgutil
import warnings

import polars as pl

from corr_vars import logger
import corr_vars.sources as sources
import corr_vars.utils as utils

from collections.abc import Iterable
from types import ModuleType
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from corr_vars.core.cohort import Cohort
    from corr_vars.core.variable import Variable
    from corr_vars.definitions import TimeBoundColumn

# A "real" source is a sub-package of ``corr_vars.sources``: it must be
# reported as a package by pkgutil, exist as a directory next to the
# package's own file, and contain an ``__init__.py`` file.
SOURCES = [
    module_info.name
    for module_info in pkgutil.iter_modules(sources.__path__)
    if module_info.ispkg
    and os.path.isdir(
        os.path.join(os.path.dirname(sources.__file__), module_info.name)
    )
    and os.path.isfile(
        os.path.join(
            os.path.dirname(sources.__file__), module_info.name, "__init__.py"
        )
    )
]


class MultiSourceVariable:
    """Concatenate one variable's data across multiple sources.

    Args:
        var_name (str): The name of the variable to concatenate.
        variables (dict[str, Variable]): The per-source variables to
            concatenate, keyed by source name.

    Attributes:
        dynamic: Taken from the first variable; all variables must agree.
        tmin: Taken from the first variable; all variables must agree.
        tmax: Taken from the first variable; all variables must agree.
        data: ``None`` until :meth:`extract` has run, afterwards the
            combined frame of the most recent extraction.

    Raises:
        ValueError: If ``variables`` is empty.
        AssertionError: If any variable disagrees with ``var_name`` or with
            the first variable's ``dynamic``, ``tmin`` or ``tmax``.
    """

    var_name: str
    dynamic: bool
    data: pl.DataFrame | None

    tmin: TimeBoundColumn | None
    tmax: TimeBoundColumn | None

    variables: dict[str, Variable]

    def __init__(self, var_name: str, variables: dict[str, Variable]) -> None:
        self.var_name = var_name
        self.variables = variables
        # Nothing has been extracted yet. Without this assignment, reading
        # ``data`` before ``extract`` would raise AttributeError.
        self.data = None
        self._sanity_check()

    def _sanity_check(self) -> None:
        """Validate the variables and derive ``dynamic``, ``tmin`` and ``tmax``."""
        if not self.variables:
            raise ValueError("No variables to concatenate")

        # Assign dynamic, tmin and tmax after verifying variables is not empty
        reference = next(iter(self.variables.values()))
        self.dynamic = reference.dynamic
        self.tmin = reference.tmin
        self.tmax = reference.tmax

        # NOTE: these checks are asserts, so they are skipped when Python
        # runs with -O. They are kept as asserts so that callers catching
        # AssertionError keep working.
        for var in self.variables.values():
            assert (
                var.var_name == self.var_name
            ), f"Variable name mismatch: {var.var_name} != {self.var_name}"
            assert (
                var.dynamic == self.dynamic
            ), f"Dynamic status mismatch for {self.var_name}: {var.dynamic} != {self.dynamic}"
            assert (
                var.tmin == self.tmin
            ), f"tmin mismatch for {self.var_name}: {var.tmin} != {self.tmin}"
            assert (
                var.tmax == self.tmax
            ), f"tmax mismatch for {self.var_name}: {var.tmax} != {self.tmax}"

    def extract(self, cohort: Cohort, with_source_column: bool = False) -> pl.DataFrame:
        """Extract the variable from every source and concatenate the results.

        Args:
            cohort (Cohort): The cohort passed to each variable's ``extract``.
            with_source_column (bool): If True, a leading ``data_source``
                column holding the source name is inserted into each frame
                before concatenation.

        Returns:
            pl.DataFrame: The concatenated data, also stored on ``self.data``.
        """
        var_data = []

        for src, var in self.variables.items():
            extracted = var.extract(cohort)
            if with_source_column:
                extracted = extracted.insert_column(0, pl.lit(src).alias("data_source"))
            var_data.append(extracted)
            logger.info(f"SUCCESS: Extracted {self.var_name} from {src}")

        # "diagonal_relaxed" is used so frames from different sources do not
        # need identical columns or dtypes.
        combined_data = pl.concat(var_data, how="diagonal_relaxed")
        self.data = combined_data
        return combined_data


def _get_src_module(source: str, submodule: str | None = None) -> ModuleType:
    """Import a source package, or one of its submodules.

    Args:
        source (str): Name of the source package below ``corr_vars.sources``.
        submodule (str | None): Optional submodule of that package. When
            empty or None, the source package itself is imported.

    Returns:
        ModuleType: The imported module.
    """
    dotted_path = f"corr_vars.sources.{source}"
    if submodule:
        dotted_path = f"{dotted_path}.{submodule}"
    return importlib.import_module(dotted_path)


def _get_src_module_mapping(
    include_sources: Iterable[str] | None = None, submodule: str | None = None
) -> dict[str, ModuleType]:
    """Import the module for each requested source.

    Args:
        include_sources (Iterable[str] | None): Source names to import.
            Falls back to the module-level ``SOURCES`` when None.
        submodule (str | None): Optional submodule to import from each
            source instead of the source package itself.

    Returns:
        dict[str, ModuleType]: Imported modules keyed by source name.
    """
    selected = SOURCES if include_sources is None else include_sources

    modules: dict[str, ModuleType] = {}
    for source_name in selected:
        modules[source_name] = _get_src_module(
            source=source_name, submodule=submodule
        )
    return modules


def _get_overrides(var_name: str, cohort: Cohort) -> dict[str, dict[str, Any]]:
    """Get the per-source configuration overrides for a variable.

    Two layouts of ``cohort.project_vars`` are supported:

    * Legacy: ``{var_name: override}``. The same override is applied to
      every source of the cohort.
    * Source-specific: ``{source: {var_name: override}}``. Only sources
      that list the variable receive an override.

    Args:
        var_name (str): Name of the variable to get overrides for.
        cohort (Cohort): The cohort to get overrides for. Its
            ``project_vars`` attribute is optional.

    Returns:
        overrides (dict[str, dict[str, Any]]): Overrides keyed by source
            name. Empty when the cohort defines no override for the variable.
    """
    project_vars = getattr(cohort, "project_vars", None)
    if not project_vars:
        return {}

    # Legacy mode -> Update all existing sources
    if var_name in project_vars:
        return {src: project_vars[var_name] for src in cohort.sources}

    # Source-specific mode. The override lives one level down, under the
    # source key; reading project_vars[var_name] here would always raise
    # KeyError because this branch is only reached when that key is absent.
    return {
        src: project_vars[src][var_name]
        for src in cohort.sources
        if var_name in (project_vars.get(src) or {})
    }


def load_default_variables(
    cohort: Cohort,
    tmin: TimeBoundColumn | None = ("hospital_admission", "-48h"),
    tmax: TimeBoundColumn | None = ("hospital_admission", "+48h"),
) -> list[MultiSourceVariable]:
    """Load the default variables for the cohort's observation level.

    For every source of the cohort, the source mapping's
    ``VARS["corr_defaults"]["default_vars"]`` is read and its ``"global"``
    entries are combined with those listed under the cohort's observation
    level (``cohort.obs_level.lower_name``).

    Args:
        cohort (Cohort): The cohort to load the default variables for.
        tmin (TimeBoundColumn | None): Minimum time for extraction (default: ("hospital_admission", "-48h")).
        tmax (TimeBoundColumn | None): Maximum time for extraction (default: ("hospital_admission", "+48h")).

    Returns:
        list[MultiSourceVariable]: One multi-source variable per default
            variable name, ready for extraction.
    """

    variable_source_mapping: dict[str, list[str]] = {}

    # Only consult the sources the cohort actually uses. A default declared
    # solely by an unused source could never be loaded and would make
    # _load_variable_from_config raise ValueError for the whole call.
    for src, src_mapping_module in _get_src_module_mapping(
        include_sources=cohort.sources.keys(), submodule="mapping"
    ).items():
        src_VARS = getattr(src_mapping_module, "VARS")

        if not (
            "corr_defaults" in src_VARS and "default_vars" in src_VARS["corr_defaults"]
        ):
            continue

        default_vars = src_VARS["corr_defaults"]["default_vars"]
        global_defaults = default_vars.get("global", [])
        obs_level_defaults = default_vars.get(cohort.obs_level.lower_name, [])

        for var_name in [*global_defaults, *obs_level_defaults]:
            var_sources = variable_source_mapping.setdefault(var_name, [])
            # A variable may be listed both globally and per observation
            # level; record each source once.
            if src not in var_sources:
                var_sources.append(src)

    variables = []
    for var_name, include_sources in variable_source_mapping.items():
        override = _get_overrides(var_name, cohort)
        var_src_config = _get_variable_src_config(
            var_name, include_sources=include_sources, override=override
        )
        variable = _load_variable_from_config(
            var_name, var_src_config, cohort, tmin, tmax
        )
        variables.append(variable)

    return variables


def _get_variable_src_config(
    var_name: str,
    include_sources: Iterable[str] | None = None,
    override: dict[str, dict[str, Any]] | None = None,
) -> dict[str, dict[str, Any]]:
    """Collect the per-source configuration of a variable.

    Each source's ``mapping`` module is consulted: the entry for
    ``var_name`` in ``VARS["variables"]`` is shallow-copied, and if the
    mapping's ``variables`` attribute exposes an attribute named after the
    variable, it is stored under the ``"py"`` key.

    Args:
        var_name (str): Name of the variable to get configuration for.
        include_sources (Iterable[str] | None): The sources to consider for the config. All sources are considered if None.
        override (dict[str, dict[str, Any]] | None): Per-source dictionaries merged on top of the loaded configuration.

    Returns:
        dict: Configuration dictionaries keyed by source name.

    Raises:
        ValueError: If no considered source defines the variable.
    """
    mapping_modules = _get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )

    configs: dict[str, dict[str, Any]] = {}
    for source_name, mapping_module in mapping_modules.items():
        declared = getattr(mapping_module, "VARS")["variables"]
        code_module = getattr(mapping_module, "variables")

        if var_name in declared:
            entry = declared[var_name].copy()
            if hasattr(code_module, var_name):
                entry["py"] = getattr(code_module, var_name)
            configs[source_name] = entry

    if not configs:
        raise ValueError(f"Variable {var_name} not found")

    # The configuration is logged as loaded, before overrides are merged in.
    logger.debug(f"Loaded config for {var_name}:")
    utils.log_dict(logger, configs, level=utils.logging.DEBUG)

    # Project vars
    if override:
        for source_name, entry in configs.items():
            entry.update(override.get(source_name, dict()))

    return configs


def load_variable(
    var_name: str,
    cohort: Cohort,
    tmin: TimeBoundColumn | None = None,
    tmax: TimeBoundColumn | None = None,
) -> MultiSourceVariable:
    """Load a variable from every source of the cohort that defines it.

    Args:
        var_name (str): The name of the variable to load.
        cohort (Cohort): The cohort to load the variable for.
        tmin (TimeBoundColumn | None): Minimum time for extraction (default: None).
        tmax (TimeBoundColumn | None): Maximum time for extraction (default: None).

    Returns:
        MultiSourceVariable: The loaded variable ready for extraction.
    """
    # Cohort-level overrides are resolved first, then merged into the
    # configuration gathered from the cohort's sources.
    project_overrides = _get_overrides(var_name, cohort)
    source_configs = _get_variable_src_config(
        var_name,
        include_sources=cohort.sources.keys(),
        override=project_overrides,
    )
    return _load_variable_from_config(var_name, source_configs, cohort, tmin, tmax)
def _load_variable_from_config(
    var_name: str,
    var_src_config: dict[str, dict[str, Any]],
    cohort: Cohort,
    tmin: TimeBoundColumn | None = None,
    tmax: TimeBoundColumn | None = None,
) -> MultiSourceVariable:
    """Build the per-source variables described by a source configuration.

    Args:
        var_name (str): The name of the variable to load.
        var_src_config (dict): The source configuration for the variable,
            keyed by source name. It is not modified.
        cohort (Cohort): The cohort to load the variable for. Sources that
            are not part of ``cohort.sources`` are skipped.
        tmin (TimeBoundColumn | None): Minimum time for extraction (default: None).
            Only used for sources whose configuration has no ``tmin``.
        tmax (TimeBoundColumn | None): Maximum time for extraction (default: None).
            Only used for sources whose configuration has no ``tmax``.

    Returns:
        MultiSourceVariable: The loaded variable ready for extraction.

    Raises:
        ValueError: If none of the configured sources belongs to the cohort.

    Warns:
        UserWarning: If ``tmin``/``tmax`` is given but the source
            configuration already fixes that bound.
    """
    src_var_mapping: dict[str, Variable] = {}

    for src, src_module in _get_src_module_mapping(
        include_sources=var_src_config.keys()
    ).items():
        if src not in cohort.sources:
            logger.info(f"Source {src} not requested in cohort. Skipping...")
            continue

        # Work on a copy so the time bounds filled in below do not leak
        # into the caller's configuration dictionaries.
        var_config = dict(var_src_config[src])
        SrcVariable = getattr(src_module, "Variable")

        # A bound fixed in the configuration wins over the argument; the
        # argument only fills the gap when the configuration is silent.
        for bound_name, bound_value in (("tmin", tmin), ("tmax", tmax)):
            if bound_name not in var_config:
                var_config[bound_name] = bound_value
            elif bound_value is not None:
                warnings.warn(
                    f"For this variable {bound_name} is already set to "
                    f"{var_config[bound_name]}. {bound_name} will be ignored.",
                    UserWarning,
                )

        src_var_mapping[src] = SrcVariable(var_name=var_name, **var_config)
        logger.info(f"Variable {var_name} found in source {src}")

    if not src_var_mapping:
        raise ValueError(
            f"Variable {var_name} not found in any source (looked in {cohort.sources.keys()})"
        )

    return MultiSourceVariable(var_name, src_var_mapping)