from __future__ import annotations
import importlib
import os
import pkgutil
import warnings
import polars as pl
from corr_vars import logger
import corr_vars.sources as sources
import corr_vars.utils as utils
from collections.abc import Iterable
from types import ModuleType
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from corr_vars.core.cohort import Cohort
from corr_vars.core.variable import Variable
from corr_vars.definitions import TimeBoundColumn
# Real sources are the sub-packages of ``corr_vars.sources`` that exist on disk
# as a directory containing an ``__init__.py`` file.
SOURCES = [
    name
    for _, name, ispkg in pkgutil.iter_modules(sources.__path__)
    if ispkg
    # Single-element loop: binds the candidate package directory to a name.
    for pkg_dir in (os.path.join(os.path.dirname(sources.__file__), name),)
    if os.path.isdir(pkg_dir)
    and os.path.isfile(os.path.join(pkg_dir, "__init__.py"))
]
class MultiSourceVariable:
    """Concatenate the same variable extracted from multiple sources.

    Args:
        var_name (str): The name of the variable to concatenate.
        variables (dict[str, Variable]): Mapping of source name to the
            variable object of that source. Must not be empty.

    Attributes:
        var_name: Name shared by all wrapped variables.
        dynamic: ``dynamic`` flag shared by all wrapped variables.
        data: Result of the last :meth:`extract` call, ``None`` before that.
        tmin: ``tmin`` shared by all wrapped variables.
        tmax: ``tmax`` shared by all wrapped variables.
        variables: The wrapped variables, keyed by source name.

    Raises:
        ValueError: If ``variables`` is empty.
        AssertionError: If the wrapped variables disagree on ``var_name``,
            ``dynamic``, ``tmin`` or ``tmax``.
    """

    var_name: str
    dynamic: bool
    data: pl.DataFrame | None
    tmin: TimeBoundColumn | None
    tmax: TimeBoundColumn | None
    variables: dict[str, Variable]

    def __init__(self, var_name: str, variables: dict[str, Variable]) -> None:
        self.var_name = var_name
        self.variables = variables
        # Nothing has been extracted yet; initialise the attribute so that
        # reading ``data`` before ``extract`` yields None instead of raising.
        self.data = None
        self._sanity_check()

    def _sanity_check(self) -> None:
        """Validate the wrapped variables and derive the shared attributes."""
        if not self.variables:
            raise ValueError("No variables to concatenate")

        # Assign dynamic, tmin and tmax after verifying variables is not empty.
        # The first variable acts as the reference for all others.
        reference = next(iter(self.variables.values()))
        self.dynamic = reference.dynamic
        self.tmin = reference.tmin
        self.tmax = reference.tmax

        for var in self.variables.values():
            assert (
                var.var_name == self.var_name
            ), f"Variable name mismatch: {var.var_name} != {self.var_name}"
            assert (
                var.dynamic == self.dynamic
            ), f"Dynamic status mismatch for {self.var_name}: {var.dynamic} != {self.dynamic}"
            assert (
                var.tmin == self.tmin
            ), f"tmin mismatch for {self.var_name}: {var.tmin} != {self.tmin}"
            assert (
                var.tmax == self.tmax
            ), f"tmax mismatch for {self.var_name}: {var.tmax} != {self.tmax}"

    def extract(self, cohort: Cohort, with_source_column: bool = False) -> pl.DataFrame:
        """Extract the variable from every source and concatenate the results.

        Args:
            cohort (Cohort): The cohort passed on to each variable's ``extract``.
            with_source_column (bool): If True, a ``data_source`` column holding
                the source name is inserted as first column of each result.

        Returns:
            pl.DataFrame: The concatenated data, which is also stored in
            ``self.data``.
        """
        var_data = []
        for src, var in self.variables.items():
            extracted = var.extract(cohort)
            if with_source_column:
                extracted = extracted.insert_column(0, pl.lit(src).alias("data_source"))
            var_data.append(extracted)
            logger.info(f"SUCCESS: Extracted {self.var_name} from {src}")

        # "diagonal_relaxed" is used so that per-source frames need not have
        # identical schemas.
        combined_data = pl.concat(var_data, how="diagonal_relaxed")
        self.data = combined_data
        return combined_data
def _get_src_module(source: str, submodule: str | None = None) -> ModuleType:
    """Import a source package, or one of its submodules.

    Args:
        source (str): Name of the source package below ``corr_vars.sources``.
        submodule (str | None): Optional submodule of the source to import
            instead of the source package itself.

    Returns:
        ModuleType: The imported module.
    """
    module_path = f"corr_vars.sources.{source}"
    if submodule:
        module_path = f"{module_path}.{submodule}"
    return importlib.import_module(module_path)
def _get_src_module_mapping(
    include_sources: Iterable[str] | None = None, submodule: str | None = None
) -> dict[str, ModuleType]:
    """Import the module of each source and return them keyed by source name.

    Args:
        include_sources (Iterable[str] | None): Sources to import. All entries
            of ``SOURCES`` are used if set to None.
        submodule (str | None): Optional submodule to import for each source
            instead of the source package itself.

    Returns:
        dict[str, ModuleType]: Mapping of source name to imported module.
    """
    selected = SOURCES if include_sources is None else include_sources
    modules: dict[str, ModuleType] = {}
    for src in selected:
        modules[src] = _get_src_module(source=src, submodule=submodule)
    return modules
def _get_overrides(var_name: str, cohort: Cohort) -> dict[str, dict[str, Any]]:
    """Get the project-specific overrides for a variable.

    Two layouts of ``cohort.project_vars`` are supported:

    * Legacy: ``{var_name: config}``. The same config is applied to every
      source of the cohort.
    * Source-specific: ``{source: {var_name: config}}``. The config is applied
      only to the sources that define it.

    Args:
        var_name (str): Name of the variable to get overrides for.
        cohort (Cohort): The cohort to get overrides for.

    Returns:
        overrides (dict[str, dict[str, Any]]): Dictionary of overrides keyed by
        source name. Empty if the cohort defines no overrides for the variable.
    """
    project_vars = getattr(cohort, "project_vars", None)
    if not project_vars:
        return {}

    # Legacy mode -> Update all existing sources
    if var_name in project_vars:
        return {src: project_vars[var_name] for src in cohort.sources}

    # New implementation with source specification. The variable lives one
    # level deeper, below the source name.
    overrides: dict[str, dict[str, Any]] = {}
    for src in cohort.sources:
        src_vars = project_vars.get(src)
        if src_vars and var_name in src_vars:
            overrides[src] = src_vars[var_name]
    return overrides
def load_default_variables(
    cohort: Cohort,
    tmin: TimeBoundColumn | None = ("hospital_admission", "-48h"),
    tmax: TimeBoundColumn | None = ("hospital_admission", "+48h"),
) -> list[MultiSourceVariable]:
    """Load the default variables for the observation level of a cohort.

    The defaults are read from ``VARS["corr_defaults"]["default_vars"]`` of
    each source's ``mapping`` module, combining the ``"global"`` entry with the
    entry for ``cohort.obs_level.lower_name``. Sources without such a section
    are skipped.

    Args:
        cohort (Cohort): The cohort to load the default variables for.
        tmin (TimeBoundColumn | None): Minimum time for extraction (default: ("hospital_admission", "-48h")).
        tmax (TimeBoundColumn | None): Maximum time for extraction (default: ("hospital_admission", "+48h")).

    Returns:
        list[MultiSourceVariable]: One variable per default variable name,
        ready for extraction.
    """
    # Only consider sources the cohort actually uses. A default that exists
    # solely in an unrequested source could not be loaded later on anyway.
    requested_sources = [src for src in SOURCES if src in cohort.sources]

    variable_source_mapping: dict[str, list[str]] = {}
    # Import the mapping submodule explicitly rather than relying on it being
    # exposed as an attribute of the source package.
    for src, src_mapping_module in _get_src_module_mapping(
        include_sources=requested_sources, submodule="mapping"
    ).items():
        src_VARS = getattr(src_mapping_module, "VARS")

        if not (
            "corr_defaults" in src_VARS and "default_vars" in src_VARS["corr_defaults"]
        ):
            continue

        default_vars = src_VARS["corr_defaults"]["default_vars"]
        global_defaults = default_vars.get("global", [])
        obs_level_defaults = default_vars.get(cohort.obs_level.lower_name, [])
        all_defaults = global_defaults + obs_level_defaults

        for var_name in all_defaults:
            var_sources = variable_source_mapping.setdefault(var_name, [])
            # A variable may be listed both globally and for the obs level.
            if src not in var_sources:
                var_sources.append(src)

    variables = []
    for var_name, include_sources in variable_source_mapping.items():
        override = _get_overrides(var_name, cohort)
        var_src_config = _get_variable_src_config(
            var_name, include_sources=include_sources, override=override
        )
        variable = _load_variable_from_config(
            var_name, var_src_config, cohort, tmin, tmax
        )
        variables.append(variable)

    return variables
def _get_variable_src_config(
    var_name: str,
    include_sources: Iterable[str] | None = None,
    override: dict[str, dict[str, Any]] | None = None,
) -> dict[str, dict[str, Any]]:
    """Collect the per-source configuration of a variable.

    Args:
        var_name (str): Name of the variable to get configuration for.
        include_sources (Iterable[str] | None): The sources to consider for the
            config. Will include all sources if set to None.
        override (dict[str, dict[str, Any]] | None): Per-source dictionaries
            used to update the configuration of the matching source.

    Returns:
        dict: Configuration dictionary for the variable, keyed by source name.

    Raises:
        ValueError: If variable is not found in the configuration of any source.
    """
    mapping_modules = _get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )

    configs: dict[str, dict[str, Any]] = {}
    for src, mapping_module in mapping_modules.items():
        declared = getattr(mapping_module, "VARS")["variables"]
        py_definitions = getattr(mapping_module, "variables")

        if var_name in declared:
            # Shallow copy so the source's mapping is left untouched
            config = declared[var_name].copy()
            if hasattr(py_definitions, var_name):
                config["py"] = getattr(py_definitions, var_name)
            configs[src] = config

    if not configs:
        raise ValueError(f"Variable {var_name} not found")

    logger.debug(f"Loaded config for {var_name}:")
    utils.log_dict(logger, configs, level=utils.logging.DEBUG)

    # Project vars: applied on top of the source configuration
    if override:
        for src, config in configs.items():
            config.update(override.get(src, dict()))

    return configs
# [docs] -- appears to be a leftover documentation-link marker; not executable code.
def load_variable(
    var_name: str,
    cohort: Cohort,
    tmin: TimeBoundColumn | None = None,
    tmax: TimeBoundColumn | None = None,
) -> MultiSourceVariable:
    """Load a variable from all sources of a cohort.

    Args:
        var_name (str): The name of the variable to load.
        cohort (Cohort): The cohort to load the variable for.
        tmin (TimeBoundColumn | None): Minimum time for extraction (default: None).
        tmax (TimeBoundColumn | None): Maximum time for extraction (default: None).

    Returns:
        MultiSourceVariable: The loaded variable ready for extraction.
    """
    project_overrides = _get_overrides(var_name, cohort)
    requested_sources = cohort.sources.keys()
    config_by_source = _get_variable_src_config(
        var_name,
        include_sources=requested_sources,
        override=project_overrides,
    )
    return _load_variable_from_config(var_name, config_by_source, cohort, tmin, tmax)
def _load_variable_from_config(
    var_name: str,
    var_src_config: dict[str, dict[str, Any]],
    cohort: Cohort,
    tmin: TimeBoundColumn | None = None,
    tmax: TimeBoundColumn | None = None,
) -> MultiSourceVariable:
    """Build the variable of each configured source and bundle them.

    Sources that are configured but not part of ``cohort.sources`` are skipped.
    A ``tmin``/``tmax`` given in the source configuration takes precedence over
    the function argument; a warning is issued if both are set.

    Args:
        var_name (str): The name of the variable to load.
        var_src_config (dict): The source configuration for the variable, keyed
            by source name. It is not modified.
        cohort (Cohort): The cohort to load the variable for.
        tmin (TimeBoundColumn | None): Minimum time for extraction (default: None).
        tmax (TimeBoundColumn | None): Maximum time for extraction (default: None).

    Returns:
        MultiSourceVariable: The loaded variable ready for extraction.

    Raises:
        ValueError: If none of the configured sources is part of the cohort.
    """
    requested_bounds = {"tmin": tmin, "tmax": tmax}
    src_var_mapping: dict[str, Variable] = {}

    for src, src_module in _get_src_module_mapping(
        include_sources=var_src_config.keys()
    ).items():
        if src not in cohort.sources:
            logger.info(f"Source {src} not requested in cohort. Skipping...")
            continue

        # Work on a copy so the caller's configuration is left untouched
        var_config = dict(var_src_config[src])
        SrcVariable = getattr(src_module, "Variable")

        for bound, requested in requested_bounds.items():
            if bound not in var_config:
                var_config[bound] = requested
            elif requested is not None:
                # The configured bound wins over the requested one
                warnings.warn(
                    f"For this variable {bound} is already set to {var_config[bound]}. {bound} will be ignored.",
                    UserWarning,
                )

        src_var_mapping[src] = SrcVariable(var_name=var_name, **var_config)
        logger.info(f"Variable {var_name} found in source {src}")

    if not src_var_mapping:
        raise ValueError(
            f"Variable {var_name} not found in any source (looked in {cohort.sources.keys()})"
        )

    return MultiSourceVariable(var_name, src_var_mapping)