from __future__ import annotations
import importlib
import os
import pkgutil
import polars as pl
from corr_vars import logger
import corr_vars.sources as sources
import corr_vars.utils as utils
from collections.abc import Iterable, Mapping
from types import ModuleType
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from corr_vars.core.cohort import Cohort
from corr_vars.core.variable import Variable, TimeWindow, TimeAnchor
from corr_vars.definitions import VariableProtocol
# Real sources are the sub-packages of ``corr_vars.sources`` that exist on disk
# as a directory containing an ``__init__.py`` file; anything else is ignored.
SOURCES = [
    module_info.name
    for module_info in pkgutil.iter_modules(sources.__path__)
    if module_info.ispkg
    and os.path.isdir(
        os.path.join(os.path.dirname(sources.__file__), module_info.name)
    )
    and os.path.isfile(
        os.path.join(
            os.path.dirname(sources.__file__), module_info.name, "__init__.py"
        )
    )
]
class MultiSourceVariable:
    """Concatenate one variable extracted from multiple sources.

    Args:
        var_name (str): The name of the variable to concatenate.
        variables (Mapping[str, VariableProtocol]): The per-source variables to
            concatenate, keyed by source name. Must be non-empty, and every
            entry must agree on ``var_name``, ``dynamic`` and ``time_window``.

    Raises:
        ValueError: If ``variables`` is empty.
        AssertionError: If the variables disagree on name, dynamic status or
            time window.
    """

    var_name: str
    variables: Mapping[str, VariableProtocol]
    data: pl.DataFrame | None
    _dynamic: bool
    _time_window: TimeWindow

    def __init__(
        self, var_name: str, variables: Mapping[str, VariableProtocol]
    ) -> None:
        self.var_name = var_name
        self.variables = variables
        # Must exist before the first extract(): `extracted`, `__repr__` and the
        # `time_window` setter all read `self.data`.
        self.data = None
        self._sanity_check()

    def _sanity_check(self) -> None:
        """Validate ``variables`` and derive the shared dynamic flag and time window."""
        if not self.variables:
            raise ValueError("No variables to concatenate")

        # Assign dynamic and time_window after verifying variables is not empty
        first_variable = next(iter(self.variables.values()))
        self._dynamic = first_variable.dynamic
        self._time_window = first_variable.time_window

        # Raised explicitly (instead of `assert`) so the checks still run under
        # `python -O`; the exception type stays AssertionError for callers.
        for var in self.variables.values():
            if var.var_name != self.var_name:
                raise AssertionError(
                    f"Variable name mismatch: {var.var_name} != {self.var_name}"
                )
            if var.dynamic != self._dynamic:
                raise AssertionError(
                    f"Dynamic status mismatch for {self.var_name}: {var.dynamic} != {self._dynamic}"
                )
            if var.time_window != self._time_window:
                raise AssertionError(
                    f"Timebound mismatch for {self.var_name}: {var.time_window} != {self._time_window}"
                )

    def extract(self, cohort: Cohort, with_source_column: bool = False) -> pl.DataFrame:
        """Extract the variable from every source and concatenate the results.

        Args:
            cohort (Cohort): The cohort passed to each source variable's ``extract``.
            with_source_column (bool): If True, a leading ``data_source`` column
                holding the source name is added to each per-source frame.

        Returns:
            pl.DataFrame: The concatenated data, also stored on ``self.data``.
        """
        var_data = []
        for src, var in self.variables.items():
            extracted = var.extract(cohort)
            if with_source_column:
                # NOTE(review): polars documents insert_column as an in-place
                # operation, so the frame returned by var.extract() may itself
                # gain this column -- confirm that is intended.
                extracted = extracted.insert_column(0, pl.lit(src).alias("data_source"))
            var_data.append(extracted)
            logger.info(f"SUCCESS: Extracted {self.var_name} from {src}")

        combined_data = pl.concat(var_data, how="diagonal_relaxed")
        self.data = combined_data
        return combined_data

    @property
    def dynamic(self) -> bool:
        """Dynamic flag shared by all wrapped variables (read-only)."""
        return self._dynamic

    @dynamic.setter
    def dynamic(self, _: bool) -> None:
        raise AttributeError("Cannot set dynamic")

    @property
    def time_window(self) -> TimeWindow:
        """Time window shared by all wrapped variables."""
        return self._time_window

    @time_window.setter
    def time_window(self, time_window: TimeWindow) -> None:
        if self.extracted:
            raise ValueError("Cannot set time_window after extraction")
        self._time_window = time_window
        # Keep every wrapped variable in sync with the new window.
        for var in self.variables.values():
            var.time_window = time_window

    @property
    def extracted(self) -> bool:
        """Whether ``extract`` has already stored data on this object."""
        return self.data is not None

    def __repr__(self) -> str:
        return (
            f"MultiSourceVariable(var_name='{self.var_name}', "
            f"dynamic={self._dynamic}, "
            f"time_window={self._time_window}, "
            f"sources={list(self.variables.keys())}, "
            f"extracted={self.extracted})"
        )

    def __str__(self) -> str:
        return self.__repr__()
def _get_src_module(source: str, submodule: str | None = None) -> ModuleType:
    """Import a source package, or one of its submodules, by name.

    Args:
        source (str): Name of the package below ``corr_vars.sources``.
        submodule (str | None): Optional submodule of that package to import
            instead of the package itself.

    Returns:
        ModuleType: The imported module.
    """
    dotted_name = f"corr_vars.sources.{source}"
    if submodule:
        dotted_name = f"{dotted_name}.{submodule}"
    return importlib.import_module(dotted_name)
def _get_src_module_mapping(
    include_sources: Iterable[str] | None = None, submodule: str | None = None
) -> dict[str, ModuleType]:
    """Import the module (or a submodule) of several sources at once.

    Args:
        include_sources (Iterable[str] | None): Source names to import. Falls
            back to the module-level ``SOURCES`` list when None.
        submodule (str | None): Optional submodule to import for each source.

    Returns:
        dict[str, ModuleType]: Imported modules keyed by source name.
    """
    selected = SOURCES if include_sources is None else include_sources
    modules: dict[str, ModuleType] = {}
    for src in selected:
        modules[src] = _get_src_module(source=src, submodule=submodule)
    return modules
def _get_overrides(var_name: str, cohort: Cohort) -> dict[str, dict[str, Any]]:
    """Get the project-level overrides for a variable, keyed by source.

    Two layouts of ``cohort.project_vars`` are supported:

    * Legacy: ``{var_name: override}`` -- the same override is applied to every
      source of the cohort.
    * Source-specific: ``{source: {var_name: override}}`` -- only sources that
      declare an override for ``var_name`` are returned.

    Args:
        var_name (str): Name of the variable to get overrides for.
        cohort (Cohort): The cohort to get overrides for.

    Returns:
        overrides (dict[str, dict[str, Any]]): Overrides keyed by source name;
            empty if the cohort defines none for this variable.
    """
    project_vars = getattr(cohort, "project_vars", None)
    if not project_vars:
        return {}

    # Legacy mode -> Update all existing sources
    if var_name in project_vars:
        return {src: project_vars[var_name] for src in cohort.sources}

    # Source-specific mode: the override lives one level down, under the source
    # key. (Looking it up as project_vars[var_name] here would always raise
    # KeyError, because this branch is only reached when var_name is not a
    # top-level key.)
    override: dict[str, dict[str, Any]] = {}
    for src in cohort.sources:
        src_overrides = project_vars.get(src, {})
        if var_name in src_overrides:
            override[src] = src_overrides[var_name]
    return override
def load_default_variables(
    cohort: Cohort, time_window: TimeWindow
) -> list[MultiSourceVariable]:
    """Load the default variables for the cohort's observation level.

    Each requested source may declare defaults under
    ``VARS["corr_defaults"]["default_vars"]`` in its ``mapping`` submodule,
    split into a ``"global"`` list and one list per observation level.

    Args:
        cohort (Cohort): The Cohort object to load variables for.
        time_window (TimeWindow): Minimum and maximum time for extraction. The
            same object is handed to every loaded variable.

    Returns:
        list[MultiSourceVariable]: List of MultiSourceVariable objects configured for the observation level.
    """
    # Only consult sources the cohort actually requested. Defaults declared
    # solely by an unrequested source cannot be loaded for this cohort and
    # would otherwise abort the whole call with "not found in any source".
    requested_sources = [src for src in SOURCES if src in cohort.sources]

    variable_source_mapping: dict[str, list[str]] = {}
    # Import the `mapping` submodule explicitly (as _get_variable_src_config
    # does) instead of relying on the source package exposing it as an attribute.
    for src, src_mapping_module in _get_src_module_mapping(
        include_sources=requested_sources, submodule="mapping"
    ).items():
        src_VARS = getattr(src_mapping_module, "VARS")

        if not (
            "corr_defaults" in src_VARS and "default_vars" in src_VARS["corr_defaults"]
        ):
            continue

        default_vars = src_VARS["corr_defaults"]["default_vars"]
        global_defaults = default_vars.get("global", [])
        obs_level_defaults = default_vars.get(cohort.obs_level.lower_name, [])

        for var_name in [*global_defaults, *obs_level_defaults]:
            var_sources = variable_source_mapping.setdefault(var_name, [])
            # A variable may be listed both globally and per observation level.
            if src not in var_sources:
                var_sources.append(src)

    variables = []
    for var_name, include_sources in variable_source_mapping.items():
        override = _get_overrides(var_name, cohort)
        var_src_config = _get_variable_src_config(
            var_name, include_sources=include_sources, override=override
        )
        variables.append(
            _load_variable_from_config(var_name, var_src_config, cohort, time_window)
        )

    return variables
def _get_variable_src_config(
    var_name: str,
    include_sources: Iterable[str] | None = None,
    override: dict[str, dict[str, Any]] | None = None,
) -> dict[str, dict[str, object]]:
    """Collect the per-source configuration of a variable.

    For every considered source, the entry ``VARS["variables"][var_name]`` of
    its ``mapping`` submodule is copied. If the mapping's ``variables`` object
    has an attribute named like the variable, it is stored under the ``"py"``
    key. Overrides are applied last, per source.

    Args:
        var_name (str): Name of the variable to get configuration for.
        include_sources (Iterable[str] | None): The sources to consider for the
            config. All sources are considered if set to None.
        override (dict[str, dict[str, Any]] | None): Per-source dictionaries
            used to update the matching source configuration.

    Returns:
        dict: Configuration dictionaries keyed by source name.

    Raises:
        ValueError: If no considered source defines the variable.
    """
    mapping_modules = _get_src_module_mapping(
        include_sources=include_sources, submodule="mapping"
    )

    configs: dict[str, dict[str, Any]] = {}
    for source_name, mapping_module in mapping_modules.items():
        declared_variables = getattr(mapping_module, "VARS")["variables"]
        code_namespace = getattr(mapping_module, "variables")

        if var_name in declared_variables:
            # Shallow copy so the source's own mapping entry is not modified.
            entry = declared_variables[var_name].copy()
            if hasattr(code_namespace, var_name):
                entry["py"] = getattr(code_namespace, var_name)
            configs[source_name] = entry

    if not configs:
        raise ValueError(f"Variable {var_name} not found")

    # Logged before overrides are applied, i.e. this shows the source defaults.
    logger.debug(f"Loaded config for {var_name}:")
    utils.log_dict(logger, configs, level=utils.logging.DEBUG, json_indent=2)

    # Project vars
    if override:
        for source_name, entry in configs.items():
            entry.update(override.get(source_name, dict()))

    return configs
[docs]
def load_variable(
    var_name: str, cohort: Cohort, time_window: TimeWindow
) -> MultiSourceVariable:
    """Load a variable from every source requested by the cohort.

    Args:
        var_name (str): The name of the variable to load.
        cohort (Cohort): The cohort to load the variable for. Its ``sources``
            determine which sources are consulted.
        time_window (TimeWindow): Time window for extraction.

    Returns:
        MultiSourceVariable: The loaded variable ready for extraction.
    """
    source_configs = _get_variable_src_config(
        var_name,
        override=_get_overrides(var_name, cohort),
        include_sources=cohort.sources.keys(),
    )
    return _load_variable_from_config(var_name, source_configs, cohort, time_window)
def _load_variable_from_config(
    var_name: str,
    var_src_config: dict[str, dict[str, Any]],
    cohort: Cohort,
    time_window: TimeWindow,
) -> MultiSourceVariable:
    """Build the per-source variables for ``var_name`` and bundle them.

    Sources present in ``var_src_config`` but not requested by the cohort are
    skipped. A ``tmin`` / ``tmax`` entry in a source configuration is removed
    from that configuration and written onto ``time_window`` instead.

    Args:
        var_name (str): The name of the variable to load.
        var_src_config (dict[str, dict[str, Any]]): The source configuration for the variable.
        cohort (Cohort): The cohort to load the variable for.
        time_window (TimeWindow): Time window for extraction.

    Returns:
        MultiSourceVariable: The loaded variable.

    Raises:
        ValueError: If none of the configured sources is part of the cohort.
    """
    per_source: dict[str, Variable] = {}
    source_modules = _get_src_module_mapping(include_sources=var_src_config.keys())

    for source_name, source_module in source_modules.items():
        if source_name not in cohort.sources:
            logger.info(f"Source {source_name} not requested in cohort. Skipping...")
            continue

        config = var_src_config[source_name]
        variable_cls = getattr(source_module, "Variable")

        # NOTE(review): a frozen bound is assigned on the caller's time_window
        # object itself, so it stays in effect for later sources and for the
        # caller -- confirm this is intended.
        if "tmin" in config:
            logger.warning(
                f"For this variable tmin is already frozen to {config['tmin']}. Any other tmin will be ignored.",
            )
            time_window.tmin = TimeAnchor(config.pop("tmin"))

        if "tmax" in config:
            logger.warning(
                f"For this variable tmax is already frozen to {config['tmax']}. Any other tmax will be ignored.",
            )
            time_window.tmax = TimeAnchor(config.pop("tmax"))

        # Source names are unique dict keys, so each one is set exactly once.
        per_source[source_name] = variable_cls.from_time_window(
            var_name=var_name, time_window=time_window, **config
        )
        logger.info(f"Variable {var_name} found in source {source_name}")

    if not per_source:
        raise ValueError(
            f"Variable {var_name} not found in any source (looked in {cohort.sources.keys()})"
        )

    return MultiSourceVariable(var_name, per_source)