from __future__ import annotations
import copy
import textwrap
import warnings
import pandas as pd
import polars as pl
from corr_vars import logger
import corr_vars.utils as utils
from corr_vars.sources.var_loader import load_variable
from typing import Any, TYPE_CHECKING
from corr_vars.definitions.constants import DYN_COLUMNS, COL_ORDER, PRIMARY_KEYS
from corr_vars.definitions.typing import ExtractedVariable, PolarsFrame
if TYPE_CHECKING:
from corr_vars.core import Cohort
from corr_vars.definitions.typing import (
TimeBoundColumn,
RequirementsIterable,
RequirementsDict,
CleaningDict,
VariableCallable,
)
class Variable:
    """Base class for all variables.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable): List of variables or dict of variables with tmin/tmax required to calculate the variable (default: []).
        tmin (TimeBoundColumn | None): The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta).
        tmax (TimeBoundColumn | None): The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta).
        py (Callable): The function to call to calculate the variable.
        py_ready_polars (bool): True if the variable code can accept polars dataframes as input and return a polars dataframe.
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable.

    Note:
        tmin and tmax can be None when you create a Variable object, but must be set before extraction.
        If you add the variable via cohort.add_variable(), it will be automatically set to the cohort's tmin and tmax.

        This base class should not be used directly; use one of the subclasses instead. These are specified in the sources submodule.

    Examples:
        Basic cleaning configuration:

        >>> cleaning = {
        ...     "value": {
        ...         "low": 10,
        ...         "high": 80
        ...     }
        ... }

        Time constraints with relative offsets:

        >>> # Extract data from 2 hours before to 6 hours after ICU admission
        >>> tmin = ("icu_admission", "-2h")
        >>> tmax = ("icu_admission", "+6h")

        Variable with dependencies:

        >>> # This variable requires other variables to be loaded first
        >>> requires = ["blood_pressure_sys", "blood_pressure_dia"]
        >>> # This variable requires other variables with fixed tmin/tmax to be loaded first
        >>> requires = {
        ...     "blood_pressure_sys_hospital": {
        ...         "template": "blood_pressure_sys",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     },
        ...     "blood_pressure_dia_hospital": {
        ...         "template": "blood_pressure_dia",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     }
        ... }
    """

    # Declared attributes; all are assigned in __init__, except required_vars
    # and data, which are populated during extraction.
    var_name: str
    dynamic: bool
    requires: RequirementsIterable
    required_vars: dict[str, ExtractedVariable]
    data: pl.DataFrame | None
    tmin: TimeBoundColumn | None
    tmax: TimeBoundColumn | None
    py: VariableCallable | None
    py_ready_polars: bool
    cleaning: CleaningDict | None

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable | None = None,
        tmin: TimeBoundColumn | None = None,
        tmax: TimeBoundColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> None:
        self.var_name = var_name
        self.dynamic = dynamic  # True if the variable is dynamic (time-series)
        # Avoid the shared-mutable-default pitfall: build a fresh list per
        # instance instead of using `requires=[]` in the signature.
        self.requires = [] if requires is None else requires
        self.required_vars = {}
        self.data = None
        self.tmin = tmin
        self.tmax = tmax
        self.py = py
        self.py_ready_polars = py_ready_polars
        self.cleaning = cleaning

    def _get_required_vars(self, cohort: Cohort) -> None:
        """Load required variables for this variable.

        Args:
            cohort (Cohort): The cohort instance to load variables from.

        Returns:
            None: Sets self.required_vars with loaded Variable objects.
        """
        if len(self.requires) == 0:
            return
        # Print dependency tree before extraction
        logger.info(f"Loading dependencies for {self.var_name}:")
        utils.log_collection(logger=logger, collection=self.requires)
        for var_name in self.requires:
            # When requires is a dict, each entry may override template/tmin/tmax.
            # Deep-copy so setdefault below does not mutate the user's config.
            overrides: RequirementsDict = {}
            if isinstance(self.requires, dict):
                overrides = copy.deepcopy(self.requires[var_name])
            if var_name in cohort.constant_vars:
                # Constant variables are taken directly from the cohort's
                # observation table; time bounds do not apply to them.
                if overrides.get("tmin") or overrides.get("tmax"):
                    warnings.warn(
                        f"tmin/tmax are ignored for the constant variable {var_name}. Please remove tmin/tmax.",
                        UserWarning,
                    )
                self.required_vars[var_name] = ExtractedVariable(
                    var_name=var_name,
                    dynamic=False,
                    data=cohort._obs.clone().select(cohort.primary_key, var_name),
                )
            else:
                # Fall back to this variable's own template/tmin/tmax when the
                # requirement entry does not specify them.
                overrides.setdefault("template", var_name)
                overrides.setdefault("tmin", self.tmin)
                overrides.setdefault("tmax", self.tmax)
                var = load_variable(
                    var_name=overrides["template"],
                    cohort=cohort,
                    tmin=overrides["tmin"],
                    tmax=overrides["tmax"],
                )
                data = var.extract(cohort)
                self.required_vars[var_name] = ExtractedVariable(
                    var_name=var.var_name, dynamic=var.dynamic, data=data
                )

    def _call_var_function(self, cohort: Cohort, convert_polars: bool = False) -> bool:
        """
        Call the variable function if it exists.

        Args:
            cohort (Cohort): The cohort instance passed to the variable function.
            convert_polars (bool): If True, the result is kept as / converted to
                a polars DataFrame; if False, the result is converted to pandas.

        Returns:
            True if the function was called,
            False otherwise.

        Operations will be applied to Variable.data directly.
        """
        if self.py is None:
            return False
        code = self.py
        logger.debug(f"Calling Variable function {self.var_name}")
        if self.py_ready_polars:
            # The cohort is deep-copied so the variable function cannot mutate it.
            self.data = code(var=self, cohort=copy.deepcopy(cohort))
            # Give type checker a hint
            assert isinstance(
                self.data, pl.DataFrame
            ), "Variable function did not return a polars DataFrame"
            if not convert_polars:
                self.data = utils.convert_to_pandas_df(self.data)
        else:
            # Legacy pandas code path
            logger.warning(
                f"Converting {self.var_name} to pandas. This is slow, consider rewriting your variable function to accept polars dataframes and then set py_ready_polars=True"
            )
            if self.data is not None:
                self.data = utils.convert_to_pandas_df(self.data)
            pd_cohort = copy.deepcopy(cohort)
            pd_cohort._obs = pd_cohort._obs.to_pandas()
            # Convert required variables to pandas
            # These do not need to be converted back as they are only used here
            for reqvar in self.required_vars.values():
                if reqvar.data is not None:
                    reqvar.data = reqvar.data.to_pandas()
            self.data = code(var=self, cohort=pd_cohort)
            # Give type checker a hint
            assert isinstance(
                self.data, pd.DataFrame
            ), "Variable function did not return a pandas DataFrame"
            if convert_polars:
                self.data = utils.convert_to_polars_df(self.data)
        return True

    def _timefilter(
        self, cohort: Cohort, always: bool = False, join_key: str = "case_id"
    ) -> None:
        """Apply time-based filtering to variable data based on tmin/tmax constraints.

        Args:
            cohort (Cohort): The cohort instance containing time constraint information.
            always (bool): Whether to apply filtering even if case_tmin/case_tmax columns don't exist (default: False).
            join_key (str): Column used to join the case time bounds onto the data (default: "case_id").

        Returns:
            None: Modifies self.data to filter rows within time constraints.
        """
        if self.data is None:
            warnings.warn(
                "Timefilter was called before extraction, skipping.", UserWarning
            )
            return
        if always:
            # Force-attach case_tmin/case_tmax derived from this variable's
            # tmin/tmax so the filter below always applies.
            df = cohort._obs.clone()
            df = utils.pl_parse_time_args(df, tmin=self.tmin, tmax=self.tmax)
            cb = utils.get_cb(
                df, join_key=join_key, primary_key=cohort.primary_key
            )  # Return df with case_tmin and case_tmax columns
            self.data = self.data.join(cb, on=join_key)
        # If these columns do not exist, we can assume that a filter is not intended for this variable
        has_columns = (
            "case_tmin" in self.data.columns and "case_tmax" in self.data.columns
        )
        if has_columns:
            if "recordtime_end" in self.data.columns:
                # Filter for any rows overlapping with case_tmin and case_tmax
                self.data = (
                    self.data
                    # Fill missings
                    .with_columns(
                        pl.when(pl.col("recordtime_end").is_null())
                        .then(pl.col("recordtime"))
                        .otherwise(pl.col("recordtime_end"))
                        .alias("recordtime_end_clean"),
                    )
                    # Ensure valid interval (end >= start), swap if necessary
                    .with_columns(
                        pl.when(pl.col("recordtime_end_clean") < pl.col("recordtime"))
                        .then(pl.col("recordtime"))
                        .otherwise(pl.col("recordtime_end_clean"))
                        .alias("recordtime_end_clean"),
                    )
                    # Calculate overlap
                    .with_columns(
                        pl.max_horizontal(["recordtime", "case_tmin"]).alias(
                            "overlap_start"
                        ),
                        pl.min_horizontal(["recordtime_end_clean", "case_tmax"]).alias(
                            "overlap_end"
                        ),
                    )
                    # Filter to only intervals that have some overlap
                    .filter(pl.col("overlap_start") <= pl.col("overlap_end"))
                    # Drop intermediate columns
                    .drop(
                        "overlap_start",
                        "overlap_end",
                        "recordtime_end_clean",
                        "case_tmin",
                        "case_tmax",
                    )
                )
            else:
                # Simple filter on recordtime
                self.data = self.data.filter(
                    pl.col("recordtime").is_between("case_tmin", "case_tmax")
                ).drop("case_tmin", "case_tmax")
            logger.info(f"Time-filter selected {len(self.data)} rows, column names:")
            logger.info(utils.pretty_join(self.data.columns, indent="\t"))
        else:
            logger.warning(
                f"Time-filter not applied to {self.var_name}, no case_tmin or case_tmax in data and not enforced"
            )

    def _apply_cleaning(self) -> None:
        """Apply cleaning rules to filter out invalid values.

        Rows whose value falls outside the configured ``low``/``high`` bounds
        (inclusive) are removed.

        Returns:
            None: Modifies self.data in place.
        """
        if self.data is None:
            warnings.warn(
                "Cleaning was called before extraction, skipping.", UserWarning
            )
            return
        if self.cleaning:
            for col, bounds in self.cleaning.items():
                # Static variables store the value under the variable name
                # rather than a generic "value" column; redirect in that case.
                if (
                    col == "value"
                    and "value" not in self.data.columns
                    and self.var_name in self.data.columns
                ):
                    col = self.var_name
                if "low" in bounds:
                    self.data = self.data.filter(pl.col(col).ge(bounds["low"]))
                if "high" in bounds:
                    self.data = self.data.filter(pl.col(col).le(bounds["high"]))

    def _add_relative_times(self, cohort: Cohort) -> None:
        """Add relative time columns to dynamic variable data.

        Args:
            cohort (Cohort): The cohort instance containing reference time information.

        Returns:
            None: Modifies self.data to include recordtime_relative and recordtime_end_relative columns.
        """
        if self.data is None:
            warnings.warn(
                "Relative time was called before extraction, skipping.", UserWarning
            )
            return
        # Join with cohort data to get the reference time
        self.data = utils.convert_to_polars_df(self.data)
        self.data = self.data.join(
            cohort._obs.select(cohort.primary_key, cohort.t_min),
            on=cohort.primary_key,
            how="left",
        )

        def relative_column(
            df: pl.DataFrame, time_col: str, reference_col: str
        ) -> pl.DataFrame:
            # Offset of time_col from reference_col in whole seconds.
            return df.with_columns(
                pl.col(time_col)
                .sub(pl.col(reference_col))
                .dt.total_seconds()
                .alias(f"{time_col}_relative")
            )

        if "recordtime" in self.data.columns:
            # Calculate relative times from icu_admission in seconds
            self.data = self.data.pipe(
                relative_column, time_col="recordtime", reference_col=cohort.t_min
            )
        # Only compute recordtime_end_relative when at least one end time is set.
        if (
            "recordtime_end" in self.data.columns
            and not self.data.select(pl.col("recordtime_end").is_null().all()).item()
        ):
            self.data = self.data.pipe(
                relative_column, time_col="recordtime_end", reference_col=cohort.t_min
            )
        # Drop the reference time column
        self.data = self.data.drop(cohort.t_min)

    def _unify_and_order_columns(self, primary_key: str) -> None:
        """Order columns according to predefined COL_ORDER and add missing columns with null values.

        Args:
            primary_key (str): The primary key column name.

        Returns:
            None: Modifies self.data to have standardized column order.
        """
        if self.data is None:
            warnings.warn(
                "Unify and order was called before extraction, skipping.", UserWarning
            )
            return
        assert (
            primary_key in PRIMARY_KEYS
        ), f"Primary key {primary_key} must be one of {PRIMARY_KEYS}."
        assert (
            primary_key in self.data.columns
        ), f"Primary key {primary_key} not in data columns."
        # Add missing dynamic columns (no keys) with null values.
        # NOTE: this must be a list, not a generator — a generator is always
        # truthy, which made the previous `if missing_cols:` guard a no-op.
        missing_cols = [col for col in DYN_COLUMNS if col not in self.data.columns]
        if missing_cols:
            self.data = self.data.with_columns(
                pl.lit(None).alias(col) for col in missing_cols
            )
        # Order columns according to COL_ORDER, then add remaining columns.
        # Use a set only for membership tests; iterate self.data.columns for
        # remaining_cols so the resulting order is deterministic.
        data_cols = set(self.data.columns)
        ordered_cols = [col for col in COL_ORDER if col in data_cols]
        remaining_cols = [col for col in self.data.columns if col not in COL_ORDER]
        if len(remaining_cols) > 0:
            logger.warning(
                "🚨 DEPRECATION ALERT 🚨 \nUnknown columns found.\n - Please update the variable definition to move these to attributes. \n - Unknown columns are no longer supported and might be deprecated soon! \n - Columns:"
                + ", ".join(remaining_cols)
            )
        self.data = self.data.select(ordered_cols + remaining_cols)
        # Uncomment after user warning
        # self.data = self.data.select(ordered_cols)

    # System methods
    def __repr__(self) -> str:
        return self.__str__()

    def __str__(self) -> str:
        # Compact tree-style summary of the variable's state.
        return textwrap.dedent(
            f"""\
            Variable: {self.var_name} ({self.__class__.__name__}, {'dynamic' if self.dynamic else 'static'})
            ├── tmin: {self.tmin}, tmax: {self.tmax}
            └── data: {'Not extracted' if self.data is None else self.data.shape}
            """
        ).strip()

    def __getstate__(self) -> dict[str, Any]:
        # Default pickling behavior; defined explicitly as an extension point.
        return self.__dict__

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
class NativeVariable(Variable):
    """Extended Variable class for native variables from data sources.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series) (default: True).
        requires (RequirementsIterable): List of variables or dict of variables with tmin/tmax required to calculate the variable (default: []).
        tmin (TimeBoundColumn | None): The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta) (default: None).
        tmax (TimeBoundColumn | None): The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta) (default: None).
        py (Callable): The function to call to calculate the variable (default: None).
        py_ready_polars (bool): True if the variable code can accept polars dataframes as input and return a polars dataframe (default: False).
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable (default: None).
        allow_caching (bool): Whether to allow caching of this variable (default: True).
    """

    def __init__(
        self,
        # Parent class parameters
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable | None = None,
        tmin: TimeBoundColumn | None = None,
        tmax: TimeBoundColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        # This class parameters
        allow_caching: bool = True,
    ) -> None:
        super().__init__(
            var_name,
            dynamic=dynamic,
            # Pass a fresh list instead of sharing a mutable default argument.
            requires=[] if requires is None else requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        # Whether extraction results for this variable may be cached.
        self.allow_caching = allow_caching

    def build_attributes(self, data: PolarsFrame) -> PolarsFrame:
        """Combines columns prefixed by attributes_ into a struct column called attributes"""
        if not self.dynamic:
            # Attributes only apply to dynamic (time-series) data.
            logger.debug("Skipping attributes for non-dynamic variables")
            return data
        attr_cols = [
            col
            for col in data.collect_schema().names()
            if col.startswith("attributes_")
        ]
        if attr_cols:
            # Strip the "attributes_" prefix to form the struct field names.
            struct_expr = {col[len("attributes_") :]: pl.col(col) for col in attr_cols}
            # Mypy does not recognise the kwargs correctly as an overload
            data = data.with_columns(pl.struct(**struct_expr).alias("attributes"))  # type: ignore[call-overload]
            data = data.drop(attr_cols)
        return data

    # def on_admission(self, select: str = "!first value") -> NativeStatic:
    #     """Create a new NativeStatic variable based on the current variable that extracts the value on admission.
    #     Args:
    #         select: Select clause specifying aggregation function and columns. Defaults to ``!first value``.
    #     Returns:
    #         NativeStatic
    #     Examples:
    #         >>> # Return the first value
    #         >>> var_adm = variable.on_admission()
    #         >>> cohort.add_variable(var_adm)
    #         >>> # Be more specific with your selection
    #         >>> var_adm = variable.on_admission("!closest(hospital_admission,0,2h) value")
    #         >>> cohort.add_variable(var_adm)
    #     """
    #     if not self.dynamic:
    #         raise ValueError("on_admission is only supported for dynamic variables")
    #     return NativeStatic(
    #         var_name=f"{self.var_name}_adm",
    #         select=select,
    #         base_var=self.var_name,
    #         tmin=self.tmin,
    #         tmax=self.tmax,
    #     )