from __future__ import annotations
import textwrap
import warnings
import polars as pl
from corr_vars import logger, __
import corr_vars.utils as utils
from corr_vars.sources import guess_variable_source
from corr_vars.definitions.typing import ExtractedVariable
from corr_vars.core.steps import CleaningStep, PyFuncStep, TimeFilterStep
from corr_vars.utils.time import TimeAnchor, TimeWindow
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from corr_vars.core import Cohort
from corr_vars.definitions.typing import (
TimeAnchorColumn,
RequirementsIterable,
RequirementsDict,
CleaningDict,
VariableCallable,
)
[docs]
class Variable:
    """Base class for all variables.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable | None): List of variables or dict of
            variables with tmin/tmax required to calculate the variable
            (default: no requirements).
        tmin (TimeAnchorColumn | None): The tmin argument. Can either be a
            string (column name) or a tuple of (column name, timedelta).
        tmax (TimeAnchorColumn | None): The tmax argument. Can either be a
            string (column name) or a tuple of (column name, timedelta).
        py (Callable): The function to call to calculate the variable.
        py_ready_polars (bool): True if the variable code can accept polars
            dataframes as input and return a polars dataframe.
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable.

    Note:
        tmin and tmax can be None when you create a Variable object, but must
        be set before extraction. If you add the variable via
        cohort.add_variable(), it will be automatically set to the cohort's
        tmin and tmax.

        This base class should not be used directly; use one of the subclasses
        instead. These are specified in the sources submodule.

    Examples:
        Basic cleaning configuration:

        >>> cleaning = {
        ...     "value": {
        ...         "low": 10,
        ...         "high": 80
        ...     }
        ... }

        Time constraints with relative offsets:

        >>> # Extract data from 2 hours before to 6 hours after ICU admission
        >>> tmin = ("icu_admission", "-2h")
        >>> tmax = ("icu_admission", "+6h")

        Variable with dependencies:

        >>> # This variable requires other variables to be loaded first
        >>> requires = ["blood_pressure_sys", "blood_pressure_dia"]

        >>> # This variable requires other variables with fixed tmin/tmax to be loaded first
        >>> requires = {
        ...     "blood_pressure_sys_hospital": {
        ...         "template": "blood_pressure_sys",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     },
        ...     "blood_pressure_dia_hospital": {
        ...         "template": "blood_pressure_dia",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     }
        ... }
    """

    var_name: str
    dynamic: bool
    data: pl.DataFrame | None
    requirements: dict[str, RequirementsDict]
    required_vars: dict[str, ExtractedVariable]
    py: VariableCallable | None
    py_ready_polars: bool
    time_window: TimeWindow
    cleaning: CleaningDict | None

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        tmin: TimeAnchorColumn | TimeAnchor | None,
        tmax: TimeAnchorColumn | TimeAnchor | None,
        requires: RequirementsIterable | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> None:
        self.var_name = var_name
        self.dynamic = dynamic
        self.time_window = TimeWindow(tmin=tmin, tmax=tmax)

        # ``None`` stands in for "no requirements" so that no mutable default
        # object is shared between calls.
        if requires is None:
            self.requirements = {}
        elif isinstance(requires, dict):
            # Already in the {save_as: config} form; stored as given.
            self.requirements = requires
        else:
            # A plain iterable of names: each one is saved under its own name
            # and carries no fixed tmin/tmax of its own.
            self.requirements = {
                r: {"template": r, "tmin": None, "tmax": None} for r in requires
            }

        self.required_vars = {}
        self.py = py
        self.py_ready_polars = py_ready_polars
        self.cleaning = cleaning
        self.data = None

    @property
    def guessed_source(self) -> str | None:
        """Result of ``guess_variable_source`` for this variable (may be None)."""
        return guess_variable_source(self)

    @classmethod
    def from_time_window(cls, time_window: TimeWindow, *args, **kwargs) -> Variable:
        """Alternate constructor taking a TimeWindow instead of tmin/tmax.

        Args:
            time_window (TimeWindow): Its ``tmin`` and ``tmax`` are passed on
                as the ``tmin``/``tmax`` keyword arguments, replacing any
                values supplied under those names in ``kwargs``.
            *args: Forwarded unchanged to the class constructor.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A new instance of ``cls``.
        """
        kwargs["tmin"] = time_window.tmin
        kwargs["tmax"] = time_window.tmax
        return cls(*args, **kwargs)

    def _get_required_vars(self, cohort: Cohort) -> None:
        """Load and extract the variables this variable depends on.

        Args:
            cohort (Cohort): The cohort instance to load variables from.

        Returns:
            None: Fills self.required_vars with one ExtractedVariable per
            entry of self.requirements, keyed by the requirement's key.
        """
        if not self.requirements:
            logger.debug(
                __("No dependencies loaded for '{var_name}'.", var_name=self.var_name)
            )
            return

        # Print dependency tree before extraction
        logger.info(
            __("Loading dependencies for '{var_name}':", var_name=self.var_name)
        )
        utils.log_collection(logger=logger, collection=self.requirements)

        # The source guess is taken once up front instead of twice per
        # requirement.
        source = self.guessed_source
        include_sources = [source] if source else None

        for save_as, config in self.requirements.items():
            # A requirement without its own tmin/tmax (missing or falsy)
            # falls back to this variable's time window.
            window = TimeWindow(
                tmin=config.get("tmin") or self.time_window.tmin,
                tmax=config.get("tmax") or self.time_window.tmax,
            )
            var = cohort.load_variable(
                variable=(config["template"], window),
                include_sources=include_sources,
            )
            data = var.extract(cohort)
            self.required_vars[save_as] = ExtractedVariable(
                var_name=var.var_name, dynamic=var.dynamic, data=data
            )

    def _call_var_function(self, cohort: Cohort) -> bool:
        """Call the variable function if it exists.

        Args:
            cohort (Cohort): Passed through to the variable function step.

        Returns:
            True if the function was called, False otherwise.
            Operations are applied to Variable.data directly.
        """
        if self.py is None:
            logger.debug(
                __(
                    "Variable function not called for {var_name}",
                    var_name=self.var_name,
                )
            )
            return False

        py_func_step = PyFuncStep(self.py, self, py_ready_polars=self.py_ready_polars)
        self.data = py_func_step.call(
            self.data, cohort, required_vars=self.required_vars
        )
        return True

    def _add_time_window(self, cohort: Cohort, join_key: str | None = None) -> None:
        """Adds tmin / tmax / join key / primary key from cohort.obs to self.data.

        This might duplicate data if the join_key (default: cohort.primary_key)
        is not unique in obs.

        Args:
            cohort (Cohort): The cohort instance containing time constraint
                information.
            join_key (str | None): Key handed to the time filter step for the
                join with cohort.obs (default: cohort.primary_key if set to
                None).

        Returns:
            None: Modifies self.data to add tmin / tmax / join key / primary
            key columns. Emits a UserWarning and does nothing when self.data
            is None.
        """
        if self.data is None:
            warnings.warn(
                "_add_time_window was called before extraction, skipping.", UserWarning
            )
            return

        time_filter_step = TimeFilterStep(self.time_window, join_key=join_key)
        self.data = time_filter_step.add_window_columns(self.data, cohort)

    def _timefilter(self) -> None:
        """Apply time-based filtering to variable data based on tmin/tmax constraints.

        Emits a UserWarning and does nothing when self.data is None.
        """
        if self.data is None:
            warnings.warn(
                "_timefilter was called before extraction, skipping.", UserWarning
            )
            return

        time_filter_step = TimeFilterStep(self.time_window)
        self.data = time_filter_step.filter(self.data)

    def _apply_cleaning(self) -> bool:
        """Apply cleaning rules to filter out invalid values.

        Returns:
            True if the data was cleaned,
            False otherwise (including when self.data is None, in which case
            a UserWarning is emitted).
            Operations will be applied to Variable.data directly.
        """
        if self.data is None:
            warnings.warn(
                "_apply_cleaning was called before extraction, skipping.", UserWarning
            )
            return False

        # The generic "value" key of the cleaning rules is aliased to this
        # variable's name.
        cleaning_step = CleaningStep(self.cleaning, aliases={"value": self.var_name})
        self.data, applied = cleaning_step.clean(self.data)
        return applied

    # System methods
    def __repr__(self) -> str:
        return self.__str__()

    def __str__(self) -> str:
        """Return a three-line summary: name/class/kind, time window, data shape."""
        kind = "dynamic" if self.dynamic else "static"
        shape = "Not extracted" if self.data is None else self.data.shape
        # Lines are joined directly rather than dedented after interpolation,
        # so the layout never depends on the interpolated values.
        return "\n".join(
            (
                f"Variable: {self.var_name} ({self.__class__.__name__}, {kind})",
                f"├── time_window: {self.time_window}",
                f"└── data: {shape}",
            )
        ).strip()

    def __getstate__(self) -> dict[str, Any]:
        """Return the instance ``__dict__`` as the pickled state."""
        return self.__dict__

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore attributes from ``state`` by updating the instance ``__dict__``."""
        self.__dict__.update(state)