from __future__ import annotations
import copy
import textwrap
import warnings
import pandas as pd
import polars as pl
from corr_vars import logger
import corr_vars.utils as utils
from corr_vars.sources.var_loader import load_variable
from corr_vars.definitions.constants import DYN_COLUMNS, COL_ORDER, PRIMARY_KEYS
from corr_vars.definitions.typing import (
ExtractedVariable,
PolarsFrame,
)
from corr_vars.utils.time import (
TimeAnchor,
TimeWindow,
WindowOverlap,
WindowRelation,
compatible_with_time_window,
)
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from corr_vars.core import Cohort
from corr_vars.definitions.typing import (
TimeAnchorColumn,
RequirementsIterable,
RequirementsDict,
CleaningDict,
VariableCallable,
)
# [docs] -- stray Sphinx "view source" link artifact; not executable code
class Variable:
    """Base class for all variables.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable): List of variables or dict of variables with tmin/tmax required to calculate the variable (default: ()).
        tmin (TimeAnchorColumn | None): The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta).
        tmax (TimeAnchorColumn | None): The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta).
        py (Callable): The function to call to calculate the variable.
        py_ready_polars (bool): True if the variable code can accept polars dataframes as input and return a polars dataframe.
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable.

    Note:
        tmin and tmax can be None when you create a Variable object, but must be set before extraction.
        If you add the variable via cohort.add_variable(), it will be automatically set to the cohort's tmin and tmax.
        This base class should not be used directly; use one of the subclasses instead. These are specified in the sources submodule.

    Examples:
        Basic cleaning configuration:

        >>> cleaning = {
        ...     "value": {
        ...         "low": 10,
        ...         "high": 80
        ...     }
        ... }

        Time constraints with relative offsets:

        >>> # Extract data from 2 hours before to 6 hours after ICU admission
        >>> tmin = ("icu_admission", "-2h")
        >>> tmax = ("icu_admission", "+6h")

        Variable with dependencies:

        >>> # This variable requires other variables to be loaded first
        >>> requires = ["blood_pressure_sys", "blood_pressure_dia"]
        >>> # This variable requires other variables with fixed tmin/tmax to be loaded first
        >>> requires = {
        ...     "blood_pressure_sys_hospital": {
        ...         "template": "blood_pressure_sys",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     },
        ...     "blood_pressure_dia_hospital": {
        ...         "template": "blood_pressure_dia",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     }
        ... }
    """

    # Declared attributes; populated in __init__ and during extraction.
    var_name: str
    dynamic: bool
    requirements: dict[str, RequirementsDict]
    required_vars: dict[str, ExtractedVariable]
    data: pl.DataFrame | None
    time_window: TimeWindow
    py: VariableCallable | None
    py_ready_polars: bool
    cleaning: CleaningDict | None

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        tmin: TimeAnchorColumn | None,
        tmax: TimeAnchorColumn | None,
        requires: RequirementsIterable = (),
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> None:
        # Metadata
        self.var_name = var_name
        self.dynamic = dynamic  # True if the variable is dynamic (time-series)
        # Tmin / Tmax
        self.time_window = TimeWindow(tmin=tmin, tmax=tmax, user_specified=True)
        # Requirements: a dict is taken as-is; a plain iterable of names is
        # normalized to {name: {"template": name, "tmin": None, "tmax": None}}.
        # (Default is an immutable tuple -- a mutable [] default is a Python pitfall.)
        self.requirements = (
            requires
            if isinstance(requires, dict)
            else {r: {"template": r, "tmin": None, "tmax": None} for r in requires}
        )
        self.required_vars = {}
        # Variable function
        self.py = py
        self.py_ready_polars = py_ready_polars
        # Cleaning
        self.cleaning = cleaning
        # Data (filled during extraction)
        self.data = None
    # [docs] -- stray Sphinx "view source" link artifact; not executable code
@classmethod
def from_time_window(cls, time_window: TimeWindow, *args, **kwargs) -> Variable:
var = cls(*args, **kwargs)
var.time_window = time_window
return var
    def _get_required_vars(self, cohort: Cohort) -> None:
        """Load required variables for this variable.

        Args:
            cohort (Cohort): The cohort instance to load variables from.

        Returns:
            None: Sets self.required_vars with loaded Variable objects.
        """
        if len(self.requirements) == 0:
            logger.debug(f"No dependencies loaded for '{self.var_name}'.")
            return
        # Print dependency tree before extraction
        logger.info(f"Loading dependencies for '{self.var_name}':")
        utils.log_collection(logger=logger, collection=self.requirements)
        for save_as, config in self.requirements.items():
            var_name = config["template"]
            if save_as in cohort.constant_vars:
                # Constant variables already live on cohort._obs, so a
                # requirement-level tmin/tmax has no effect on them.
                # NOTE(review): the warning also requires the variable's own
                # window to be user-specified -- confirm that is intended.
                if (config["tmin"] and self.time_window.user_specified_tmin) or (
                    config["tmax"] and self.time_window.user_specified_tmax
                ):
                    warnings.warn(
                        f"tmin/tmax are ignored for the constant variable {var_name}. "
                        "Please remove tmin/tmax.",
                        UserWarning,
                    )
                self.required_vars[save_as] = ExtractedVariable(
                    var_name=var_name,
                    dynamic=False,
                    # Pull only key + value columns from a clone of obs
                    data=cohort._obs.clone().select(cohort.primary_key, var_name),
                )
            else:
                # Requirement-specific tmin/tmax (if set) override this
                # variable's own window for the dependency.
                var = load_variable(
                    var_name=var_name,
                    cohort=cohort,
                    time_window=TimeWindow(
                        config["tmin"] or self.time_window.tmin,
                        config["tmax"] or self.time_window.tmax,
                    ),
                )
                data = var.extract(cohort)
                self.required_vars[save_as] = ExtractedVariable(
                    var_name=var.var_name, dynamic=var.dynamic, data=data
                )
    def _call_var_function(self, cohort: Cohort) -> bool:
        """
        Call the variable function if it exists.

        Args:
            cohort (Cohort)

        Returns:
            True if the function was called,
            False otherwise.

        Operations will be applied to Variable.data directly.
        """
        if self.py is None:
            logger.debug(f"Variable function not called for {self.var_name}")
            return False
        code = self.py
        logger.debug(f"Calling Variable function {self.var_name}")
        if self.py_ready_polars:
            # Deep-copy the cohort so the user function cannot mutate shared state.
            data = code(var=self, cohort=copy.deepcopy(cohort))
            # Give type checker a hint
            if not isinstance(data, pl.DataFrame):
                raise TypeError("Variable function did not return a polars DataFrame")
            self.data = data
        else:
            # Legacy pandas code path
            logger.warning(
                f"Converting {self.var_name} to pandas. This is slow, consider rewriting your variable function to accept polars dataframes and then set py_ready_polars=True"
            )
            # Ignore all type errors to keep the rest of the code clean.
            # Only operate on copies of Cohort and Variable to avoid
            # pandas Dataframes if an error occurs inside this code block.
            # Convert variable data to pandas (None for complex variables)
            pd_var = copy.deepcopy(self)
            if pd_var.data is not None:
                pd_var.data = utils.convert_to_pandas_df(pd_var.data)  # type: ignore
            # Convert obs data to pandas
            pd_cohort = copy.deepcopy(cohort)
            pd_cohort._obs = utils.convert_to_pandas_df(pd_cohort._obs)  # type: ignore
            # Convert required variables to pandas
            for reqvar in pd_var.required_vars.values():
                if reqvar.data is not None:
                    reqvar.data = utils.convert_to_pandas_df(reqvar.data)  # type: ignore
            data = code(var=pd_var, cohort=pd_cohort)
            # Give type checker a hint
            if not isinstance(data, pd.DataFrame):
                raise TypeError("Variable function did not return a pandas DataFrame")
            # Convert the pandas result back to polars before storing it.
            self.data = utils.convert_to_polars_df(data)
        return True
    def _add_time_window(self, cohort: Cohort, join_key: str | None = None) -> None:
        """Adds tmin / tmax / join key / primary key from cohort.obs to self.data.

        This might duplicate data if the join_key (default: cohort.primary_key) is not unique in obs.

        Args:
            cohort (Cohort): The cohort instance containing time constraint information.
            join_key (str | None): Column to join the time-window frame on
                (default: cohort.primary_key if set to None).

        Returns:
            None: Modifies self.data to add tmin / tmax / join key / primary key columns.
        """
        if self.data is None:
            warnings.warn(
                "_add_time_window was called before extraction, skipping.", UserWarning
            )
            return
        obs = cohort._obs.clone()
        # Remove NA tmin/tmax and tmin/tmax in year 9999, etc.
        obs = utils.remove_invalid_time_window(obs, time_window=self.time_window)
        # Adds "tmin" and "tmax" columns resolved from the variable's time window
        obs = utils.add_time_window_expr(
            obs, time_window=self.time_window, aliases=("tmin", "tmax")
        )
        # Build the per-case bounds frame keyed by join_key / primary_key
        # (presumably emits "case_tmin"/"case_tmax" -- confirm in utils.get_cb)
        join_key = join_key or cohort.primary_key
        cb = utils.get_cb(
            obs,
            join_key=join_key,
            primary_key=cohort.primary_key,
            tmin_col="tmin",
            tmax_col="tmax",
        )
        # Inner join: rows without a matching join_key in obs are dropped.
        self.data = self.data.join(cb, on=join_key)
    def _timefilter(
        self,
        time_window: TimeWindow | None = None,
        relation: WindowRelation = WindowRelation.ANY_OVERLAP,
    ) -> None:
        """Apply time-based filtering to variable data based on tmin/tmax constraints.

        Args:
            time_window (TimeWindow | None): Window to filter against. Defaults
                to the ("case_tmin", "case_tmax") columns added by _add_time_window.
            relation (WindowRelation): Required relation between a record's
                interval and the window (default: ANY_OVERLAP).

        Returns:
            None: Filters self.data in place and drops the window's tmin/tmax columns.
        """
        if self.data is None:
            warnings.warn(
                "_timefilter was called before extraction, skipping.", UserWarning
            )
            return
        window = time_window or TimeWindow("case_tmin", "case_tmax")
        tmin = window.tmin
        tmax = window.tmax
        # If these columns do not exist, we can assume that a filter is not intended for this variable
        has_time_columns = compatible_with_time_window(self.data, window)
        if not has_time_columns:
            logger.warning(
                f"Time-filter not applied to {self.var_name}, no {tmin} or {tmax} in data and not enforced"
            )
            return
        start_col = "recordtime"
        if "recordtime_end" in self.data.columns:
            # Prepare data
            end_col = "recordtime_end"
            end_cleaned_col = "recordtime_end_clean"
            self.data = (
                self.data
                # Ensure valid interval (end >= start)
                # Will also fill missing in end_col with values from start_col
                .with_columns(
                    pl.max_horizontal(start_col, end_col).alias(end_cleaned_col),
                )
            )
            # Interval filter on recordtime and recordtime_end
            record_window = TimeWindow(start_col, end_cleaned_col)
            self.data = self.data.filter(
                WindowOverlap.overlap_expr(
                    record=record_window,
                    window=window,
                    relation=relation,
                )
            )
            # Remove end_cleaned column (temporary helper only)
            self.data = self.data.drop(end_cleaned_col)
        else:
            # Simple filter on recordtime: keep rows whose timestamp lies in the window
            record_anchor = TimeAnchor(start_col)
            self.data = self.data.filter(
                WindowOverlap.contains_expr(record_anchor, window)
            )
        # Remove tmin / tmax columns
        # NOTE(review): assumes tmin.column != tmax.column; a shared column
        # would be dropped twice -- confirm TimeWindow forbids that.
        if tmin.column:
            self.data = self.data.drop(tmin.column)
        if tmax.column:
            self.data = self.data.drop(tmax.column)
        logger.info(f"Time-filter selected {len(self.data)} rows, column names:")
        logger.info(utils.pretty_join(self.data.columns, indent="\t"))
def _apply_cleaning(self) -> bool:
"""Apply cleaning rules to filter out invalid values.
Returns:
True if the data was cleaned,
False otherwise.
Operations will be applied to Variable.data directly.
"""
if self.data is None:
warnings.warn(
"_apply_cleaning was called before extraction, skipping.", UserWarning
)
return False
if not self.cleaning:
return False
for col, bounds in self.cleaning.items():
if (
col == "value"
and "value" not in self.data.columns
and self.var_name in self.data.columns
):
col = self.var_name
if "low" in bounds:
self.data = self.data.filter(pl.col(col).ge(bounds["low"]))
if "high" in bounds:
self.data = self.data.filter(pl.col(col).le(bounds["high"]))
return True
    def _add_relative_times(self, cohort: Cohort) -> None:
        """Add relative time columns to dynamic variable data.

        Args:
            cohort (Cohort): The cohort instance containing reference time information.

        Returns:
            None: Modifies self.data to include recordtime_relative and recordtime_end_relative columns.
        """
        if self.data is None:
            warnings.warn(
                "_add_relative_times was called before extraction, skipping.",
                UserWarning,
            )
            return
        # Join with cohort data to get the reference time
        self.data = utils.convert_to_polars_df(self.data)
        self.data = self.data.join(
            cohort._obs.select(cohort.primary_key, cohort.t_min),
            on=cohort.primary_key,
            how="left",
        )

        def relative_column(
            df: pl.DataFrame, time_col: str, reference_col: str
        ) -> pl.DataFrame:
            # Seconds between time_col and reference_col as "<time_col>_relative"
            return df.with_columns(
                pl.col(time_col)
                .sub(pl.col(reference_col))
                .dt.total_seconds()
                .alias(f"{time_col}_relative")
            )

        if "recordtime" in self.data.columns:
            # Calculate relative times from the cohort reference time (t_min) in seconds
            self.data = self.data.pipe(
                relative_column, time_col="recordtime", reference_col=cohort.t_min
            )
        if (
            "recordtime_end" in self.data.columns
            and not self.data["recordtime_end"].is_null().all()
        ):
            # Only compute end offsets when at least one end timestamp is present
            self.data = self.data.pipe(
                relative_column, time_col="recordtime_end", reference_col=cohort.t_min
            )
        # Drop the reference time column
        self.data = self.data.drop(cohort.t_min)
def _unify_and_order_columns(self, primary_key: str) -> None:
"""Order columns according to predefined COL_ORDER and add missing columns with null values.
Args:
primary_key (str): The primary key column name.
Returns:
None: Modifies self.data to have standardized column order.
"""
if self.data is None:
warnings.warn(
"_unify_and_order_columns was called before extraction, skipping.",
UserWarning,
)
return
assert (
primary_key in PRIMARY_KEYS
), f"Primary key {primary_key} must be one of {PRIMARY_KEYS}."
assert (
primary_key in self.data.columns
), f"Primary key {primary_key} not in data columns."
# Add missing dynamic columns (no keys) with null values
missing_cols = (col for col in DYN_COLUMNS if col not in self.data.columns)
if missing_cols:
self.data = self.data.with_columns(
pl.lit(None).alias(col) for col in missing_cols
)
# Order columns according to COL_ORDER, then add remaining columns
data_cols = set(self.data.columns)
ordered_cols = [col for col in COL_ORDER if col in data_cols]
remaining_cols = [col for col in data_cols if col not in COL_ORDER]
if len(remaining_cols) > 0:
warnings.warn(
"🚨 DEPRECATION ALERT 🚨\n"
"Unknown columns found.\n"
"- Please update the variable definition to move these to attributes.\n"
"- Unknown columns are no longer supported and might be deprecated soon!\n"
"- Columns: " + ", ".join(remaining_cols),
UserWarning,
)
self.data = self.data.select(ordered_cols + remaining_cols)
# Uncomment after user warning
# self.data = self.data.select(ordered_cols)
# System methods
def __repr__(self) -> str:
return self.__str__()
def __str__(self) -> str:
return textwrap.dedent(
f"""\
Variable: {self.var_name} ({self.__class__.__name__}, {'dynamic' if self.dynamic else 'static'})
├── time_window: {self.time_window}
└── data: {'Not extracted' if self.data is None else self.data.shape}
"""
).strip()
def __getstate__(self) -> dict[str, Any]:
return self.__dict__
def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)
# [docs] -- stray Sphinx "view source" link artifact; not executable code
class NativeVariable(Variable):
    """Extended Variable class for native variables from data sources.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series) (default: True).
        requires (RequirementsIterable): List of variables or dict of variables with tmin/tmax required to calculate the variable (default: ()).
        tmin (TimeAnchorColumn | None): The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta) (default: None).
        tmax (TimeAnchorColumn | None): The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta) (default: None).
        py (Callable): The function to call to calculate the variable (default: None).
        py_ready_polars (bool): True if the variable code can accept polars dataframes as input and return a polars dataframe (default: False).
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable (default: None).
        allow_caching (bool): Whether to allow caching of this variable (default: True).
    """

    def __init__(
        self,
        # Parent class parameters
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable = (),
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        # This class parameters
        allow_caching: bool = True,
    ) -> None:
        super().__init__(
            var_name,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        # Whether this variable's extracted data may be cached.
        self.allow_caching = allow_caching
    # [docs] -- stray Sphinx "view source" link artifact; not executable code
def build_attributes(self, data: PolarsFrame) -> PolarsFrame:
"""Combines columns prefixed by attributes_ into a struct column called attributes"""
if not self.dynamic:
logger.debug("Skipping attribtues for non-dynamic variables")
return data
attr_cols = [
col
for col in data.collect_schema().names()
if col.startswith("attributes_")
]
if attr_cols:
struct_expr = {col[len("attributes_") :]: pl.col(col) for col in attr_cols}
# Mypy does not recognise the kwargs correctly as an overload
data = data.with_columns(pl.struct(**struct_expr).alias("attributes")) # type: ignore[call-overload]
data = data.drop(attr_cols)
return data
# def on_admission(self, select: str = "!first value") -> NativeStatic:
# """Create a new NativeStatic variable based on the current variable that extracts the value on admission.
# Args:
# select: Select clause specifying aggregation function and columns. Defaults to ``!first value``.
# Returns:
# NativeStatic
# Examples:
# >>> # Return the first value
# >>> var_adm = variable.on_admission()
# >>> cohort.add_variable(var_adm)
# >>> # Be more specific with your selection
# >>> var_adm = variable.on_admission("!closest(hospital_admission,0,2h) value")
# >>> cohort.add_variable(var_adm)
# """
# if not self.dynamic:
# raise ValueError("on_admission is only supported for dynamic variables")
# return NativeStatic(
# var_name=f"{self.var_name}_adm",
# select=select,
# base_var=self.var_name,
# tmin=self.tmin,
# tmax=self.tmax,
# )