from __future__ import annotations
import copy
from datetime import datetime, date
from enum import Flag, auto
import textwrap
import warnings
import pandas as pd
import polars as pl
from corr_vars import logger
import corr_vars.utils as utils
from corr_vars.sources.var_loader import load_variable
from typing import Any, TypeAlias, TYPE_CHECKING
from corr_vars.definitions.constants import DYN_COLUMNS, COL_ORDER, PRIMARY_KEYS
from corr_vars.definitions.typing import (
ExtractedVariable,
PolarsFrame,
is_time_anchor_column,
)
if TYPE_CHECKING:
from corr_vars.core import Cohort
from corr_vars.definitions.typing import (
TimeAnchorColumn,
RequirementsIterable,
RequirementsDict,
CleaningDict,
VariableCallable,
)
# Literal (absolute) forms accepted for a time anchor.
# NOTE(review): `int` is listed here, but TimeAnchor.__init__ currently rejects
# plain ints (only datetime/date literals pass the type check) — confirm intent.
DateLiteral: TypeAlias = datetime | date | int
class TimeAnchor:
    """One time bound of a window.

    A bound is either an absolute datetime/date literal, a column name, or a
    ``(column_name, offset)`` tuple where ``offset`` is a polars duration
    string such as ``"-2h"`` or ``"+6h"``.
    """

    column: str | None  # source column name; None for literal anchors
    delta: str | None  # normalized offset string; None unless relative
    expr: pl.Expr  # polars expression yielding the anchor, aliased "time"

    def __init__(self, bound: TimeAnchorColumn | DateLiteral) -> None:
        if not is_time_anchor_column(bound) and not isinstance(
            bound, (datetime, date)
        ):
            # NOTE(review): DateLiteral declares `int` as valid, but ints are
            # rejected here — confirm whether int support is planned.
            raise TypeError(
                "Expected bound to be a datetime/date literal, a column name, "
                "or a tuple of (column name, offset)."
            )
        if isinstance(bound, (datetime, date)):
            # Absolute anchor: a constant timestamp.
            self.column = None
            self.delta = None
            self.expr = pl.lit(bound)
        elif isinstance(bound, str):
            # Column anchor: take the timestamp directly from a column.
            self.column = bound
            self.delta = None
            self.expr = pl.col(self.column)
        else:
            # Relative anchor: column shifted by a fixed offset.
            self.column, delta = bound
            # polars offset strings carry an explicit "-" but no "+".
            self.delta = delta.removeprefix("+")
            self.expr = pl.col(self.column).dt.offset_by(self.delta)
        self.expr = self.expr.alias("time")

    def __repr__(self) -> str:
        if self.column is None or self.delta is None:
            return self.column or (
                f"Literal: {pl.select(self.expr).item()}"
                if self.expr.meta.is_literal()
                else ", ".join(self.expr.meta.root_names())
            )
        else:
            op = "-" if self.delta[0] == "-" else "+"
            rest = self.delta.removeprefix("-")
            return f"{self.column} {op} {rest}"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TimeAnchor):
            return False
        # expr.meta.eq compares the expression trees structurally.
        return (
            self.column == other.column
            and self.delta == other.delta
            and self.expr.meta.eq(other.expr)
        )

    @property
    def is_relative(self) -> bool:
        """True if this anchor is a column shifted by an offset."""
        return self.column is not None and self.delta is not None
class TimeWindow:
    """A ``[tmin, tmax]`` extraction window built from two :class:`TimeAnchor` bounds.

    Missing bounds fall back to ``datetime.min`` / ``datetime.max`` (i.e. an
    effectively unbounded window). Once :meth:`freeze` has been called, the
    bounds can no longer be reassigned.
    """

    _tmin: TimeAnchor
    _tmax: TimeAnchor
    user_specified_tmin: bool
    user_specified_tmax: bool
    frozen: bool

    def __init__(
        self,
        tmin: TimeAnchorColumn | DateLiteral | None = None,
        tmax: TimeAnchorColumn | DateLiteral | None = None,
        user_specified: bool | tuple[bool, bool] = False,
        frozen: bool = False,
    ) -> None:
        # `or` substitutes the unbounded default for None (and any falsy value).
        self._tmin = TimeAnchor(tmin or datetime.min)
        self._tmax = TimeAnchor(tmax or datetime.max)
        # A single bool applies to both bounds; a tuple sets them separately.
        if isinstance(user_specified, tuple):
            self.user_specified_tmin, self.user_specified_tmax = user_specified
        else:
            self.user_specified_tmin = user_specified
            self.user_specified_tmax = user_specified
        self.frozen = frozen

    def __repr__(self) -> str:
        return f"TimeWindow(tmin={self._tmin}, tmax={self._tmax})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TimeWindow):
            return False
        return self._tmin == other._tmin and self._tmax == other._tmax

    def freeze(self) -> None:
        """Make the bounds immutable; further assignment raises ValueError."""
        self.frozen = True

    @property
    def user_specified(self) -> bool:
        """True if at least one bound was explicitly set by the user."""
        return self.user_specified_tmin or self.user_specified_tmax

    @property
    def tmin(self) -> TimeAnchor:
        return self._tmin

    @property
    def tmax(self) -> TimeAnchor:
        return self._tmax

    @tmin.setter
    def tmin(self, value: TimeAnchor) -> None:
        if self.frozen:
            raise ValueError("Time bounds are frozen")
        self._tmin = value

    @tmax.setter
    def tmax(self, value: TimeAnchor) -> None:
        if self.frozen:
            raise ValueError("Time bounds are frozen")
        self._tmax = value

    # NOTE: the *_expr getters previously declared an unusable `alias`
    # parameter — property getters cannot receive arguments, so the default
    # alias was the only reachable value. The aliases are now hardcoded.

    @property
    def tmin_expr(self) -> pl.Expr:
        """Lower-bound expression, aliased ``"tmin"``."""
        return self._tmin.expr.alias("tmin")

    @property
    def tmax_expr(self) -> pl.Expr:
        """Upper-bound expression, aliased ``"tmax"``."""
        return self._tmax.expr.alias("tmax")

    @property
    def duration_expr(self) -> pl.Expr:
        """Window length (tmax - tmin), aliased ``"duration"``."""
        return (self.tmax_expr - self.tmin_expr).alias("duration")

    @property
    def tmid_expr(self) -> pl.Expr:
        """Midpoint of the window, aliased ``"tmid"``.

        Fixes the previous implementation, which added the *full* duration to
        tmin (yielding tmax) and aliased the result ``"tmin"``.
        """
        # Halving a polars Duration expression keeps the Duration dtype.
        return (self.tmin_expr + self.duration_expr / 2).alias("tmid")

    @property
    def is_relative(self) -> bool:
        """True if either bound is a column-plus-offset anchor."""
        return self.tmin.is_relative or self.tmax.is_relative

    @property
    def any_timebound_in_year_9999_expr(self) -> pl.Expr:
        # Year 9999 is used as a "missing/invalid" sentinel in the source data.
        return pl.any_horizontal(
            self.tmin_expr.dt.year().eq(pl.lit(9999)).fill_null(False),
            self.tmax_expr.dt.year().eq(pl.lit(9999)).fill_null(False),
        )

    @property
    def any_timebound_null_expr(self) -> pl.Expr:
        return pl.any_horizontal(
            self.tmin_expr.is_null(),
            self.tmax_expr.is_null(),
        )

    @property
    def tmin_after_tmax_expr(self) -> pl.Expr:
        return self.tmin_expr > self.tmax_expr

    @property
    def any_timebound_invalid_expr(self) -> pl.Expr:
        """True when the window is null, sentinel-dated, or inverted."""
        return pl.any_horizontal(
            self.any_timebound_null_expr,
            self.any_timebound_in_year_9999_expr,
            self.tmin_after_tmax_expr,
        )
class IntervalRelation(Flag):
    """How a record's interval relates to a query window ``[tmin, tmax]``.

    Values are explicit (matching the previous ``auto()`` assignment) because
    they are persisted via ``.value`` in classification expressions.
    """

    CONTAINED = 1  # fully inside [tmin, tmax]
    OVERLAPS_LEFT = 2  # starts before tmin and ends inside
    OVERLAPS_RIGHT = 4  # starts inside and ends after tmax
    COVERS_WINDOW = 8  # spans over both boundaries

    # Common combinations
    ANY_LEFT_OVERLAP = OVERLAPS_LEFT | CONTAINED | COVERS_WINDOW
    ANY_RIGHT_OVERLAP = OVERLAPS_RIGHT | CONTAINED | COVERS_WINDOW
    ANY_OVERLAP = CONTAINED | OVERLAPS_LEFT | OVERLAPS_RIGHT | COVERS_WINDOW
class IntervalOverlap:
    """Builders for polars boolean/classification expressions describing how a
    record interval relates to a query window."""

    @staticmethod
    def relation_expr(
        record: TimeWindow,
        window: TimeWindow,
    ) -> dict[IntervalRelation, pl.Expr]:
        """Return one boolean expression per elementary relation flag."""
        rec_start, rec_end = record.tmin_expr, record.tmax_expr
        win_lo, win_hi = window.tmin_expr, window.tmax_expr
        return {
            IntervalRelation.CONTAINED: (rec_start >= win_lo) & (rec_end <= win_hi),
            IntervalRelation.OVERLAPS_LEFT: (
                (rec_start < win_lo) & (rec_end >= win_lo) & (rec_end <= win_hi)
            ),
            IntervalRelation.OVERLAPS_RIGHT: (
                (rec_start >= win_lo) & (rec_start <= win_hi) & (rec_end > win_hi)
            ),
            IntervalRelation.COVERS_WINDOW: (rec_start < win_lo) & (rec_end > win_hi),
        }

    @staticmethod
    def overlap_expr(
        record: TimeWindow,
        window: TimeWindow,
        relation: IntervalRelation,
    ) -> pl.Expr:
        """OR together the conditions of every elementary flag in `relation`.

        Raises:
            ValueError: If `relation` contains no known flag.
        """
        selected = [
            cond
            for flag, cond in IntervalOverlap.relation_expr(record, window).items()
            if relation & flag
        ]
        if not selected:
            raise ValueError(f"Invalid relation: {relation}")
        combined = selected[0]
        for cond in selected[1:]:
            combined = combined | cond
        return combined

    @staticmethod
    def classify_expr(
        record: TimeWindow,
        window: TimeWindow,
        *,
        default: IntervalRelation | None = None,
    ) -> pl.Expr:
        """Build a when/then chain yielding the matching relation's ``.value``.

        Branches are evaluated CONTAINED first, then the overlap variants; rows
        matching nothing get ``default.value`` (or null when no default).
        """
        conditions = IntervalOverlap.relation_expr(record, window)
        chain: Any = None
        for rel in (
            IntervalRelation.CONTAINED,
            IntervalRelation.OVERLAPS_LEFT,
            IntervalRelation.OVERLAPS_RIGHT,
            IntervalRelation.COVERS_WINDOW,
        ):
            branch = pl.when(conditions[rel]) if chain is None else chain.when(
                conditions[rel]
            )
            chain = branch.then(pl.lit(rel.value))
        fallback = pl.lit(default.value) if default is not None else pl.lit(None)
        return chain.otherwise(fallback)
class Variable:
    """Base class for all variables.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable | None): List of variables or dict of variables with tmin/tmax required to calculate the variable (default: None).
        tmin (TimeBoundColumn | None): The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta).
        tmax (TimeBoundColumn | None): The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta).
        py (Callable): The function to call to calculate the variable.
        py_ready_polars (bool): True if the variable code can accept polars dataframes as input and return a polars dataframe.
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable.

    Note:
        tmin and tmax can be None when you create a Variable object, but must be set before extraction.
        If you add the variable via cohort.add_variable(), it will be automatically set to the cohort's tmin and tmax.
        This base class should not be used directly; use one of the subclasses instead. These are specified in the sources submodule.

    Examples:
        Basic cleaning configuration:

        >>> cleaning = {
        ...     "value": {
        ...         "low": 10,
        ...         "high": 80
        ...     }
        ... }

        Time constraints with relative offsets:

        >>> # Extract data from 2 hours before to 6 hours after ICU admission
        >>> tmin = ("icu_admission", "-2h")
        >>> tmax = ("icu_admission", "+6h")

        Variable with dependencies:

        >>> # This variable requires other variables to be loaded first
        >>> requires = ["blood_pressure_sys", "blood_pressure_dia"]
        >>> # This variable requires other variables with fixed tmin/tmax to be loaded first
        >>> requires = {
        ...     "blood_pressure_sys_hospital": {
        ...         "template": "blood_pressure_sys",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     },
        ...     "blood_pressure_dia_hospital": {
        ...         "template": "blood_pressure_dia",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     }
        ... }
    """

    var_name: str
    dynamic: bool
    requirements: dict[str, RequirementsDict]
    required_vars: dict[str, ExtractedVariable]
    data: pl.DataFrame | None
    time_window: TimeWindow
    py: VariableCallable | None
    py_ready_polars: bool
    cleaning: CleaningDict | None

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        tmin: TimeAnchorColumn | None,
        tmax: TimeAnchorColumn | None,
        requires: RequirementsIterable | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> None:
        # `None` sentinel instead of a mutable `[]` default argument.
        if requires is None:
            requires = {}
        # Metadata
        self.var_name = var_name
        self.dynamic = dynamic  # True if the variable is dynamic (time-series)
        # Tmin / Tmax
        self.time_window = TimeWindow(tmin=tmin, tmax=tmax, user_specified=True)
        # Requirements: normalize a plain iterable of names into the dict form.
        self.requirements = (
            requires
            if isinstance(requires, dict)
            else {r: {"template": r, "tmin": None, "tmax": None} for r in requires}
        )
        self.required_vars = {}
        # Variable function
        self.py = py
        self.py_ready_polars = py_ready_polars
        # Cleaning
        self.cleaning = cleaning
        # Data (filled by extraction)
        self.data = None

    @classmethod
    def from_time_window(
        cls,
        var_name: str,
        dynamic: bool,
        time_window: TimeWindow,
        requires: RequirementsIterable | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> Variable:
        """Alternate constructor taking a prebuilt TimeWindow instead of tmin/tmax."""
        var = cls(
            var_name=var_name,
            dynamic=dynamic,
            tmin=None,
            tmax=None,
            requires=requires if requires is not None else {},
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        var.time_window = time_window
        return var

    def _get_required_vars(self, cohort: Cohort) -> None:
        """Load required variables for this variable.

        Args:
            cohort (Cohort): The cohort instance to load variables from.

        Returns:
            None: Sets self.required_vars with loaded Variable objects.
        """
        if len(self.requirements) == 0:
            logger.debug(f"No dependencies loaded for '{self.var_name}'.")
            return
        # Print dependency tree before extraction
        logger.info(f"Loading dependencies for '{self.var_name}':")
        utils.log_collection(logger=logger, collection=self.requirements)
        for save_as, config in self.requirements.items():
            var_name = config["template"]
            if save_as in cohort.constant_vars:
                # Constant variables live directly on cohort._obs; per-variable
                # time bounds make no sense for them.
                if (config["tmin"] and self.time_window.user_specified_tmin) or (
                    config["tmax"] and self.time_window.user_specified_tmax
                ):
                    warnings.warn(
                        f"tmin/tmax are ignored for the constant variable {var_name}. "
                        "Please remove tmin/tmax.",
                        UserWarning,
                    )
                self.required_vars[save_as] = ExtractedVariable(
                    var_name=var_name,
                    dynamic=False,
                    data=cohort._obs.clone().select(cohort.primary_key, var_name),
                )
            else:
                var = load_variable(
                    var_name=var_name,
                    cohort=cohort,
                    time_window=TimeWindow(config["tmin"], config["tmax"]),
                )
                data = var.extract(cohort)
                self.required_vars[save_as] = ExtractedVariable(
                    var_name=var.var_name, dynamic=var.dynamic, data=data
                )

    def _call_var_function(self, cohort: Cohort) -> bool:
        """
        Call the variable function if it exists.

        Args:
            cohort (Cohort)

        Returns:
            True if the function was called,
            False otherwise.
            Operations will be applied to Variable.data directly.
        """
        if self.py is None:
            logger.debug(f"Variable function not called for {self.var_name}")
            return False
        code = self.py
        logger.debug(f"Calling Variable function {self.var_name}")
        if self.py_ready_polars:
            # Deep-copy the cohort so the user function cannot mutate it.
            data = code(var=self, cohort=copy.deepcopy(cohort))
            # Give type checker a hint
            if not isinstance(data, pl.DataFrame):
                raise TypeError("Variable function did not return a polars DataFrame")
            self.data = data
        else:
            # Legacy pandas code path
            logger.warning(
                f"Converting {self.var_name} to pandas. This is slow, consider rewriting your variable function to accept polars dataframes and then set py_ready_polars=True"
            )
            # Ignore all type errors to keep the rest of the code clean.
            # Only operate on copies of Cohort and Variable to avoid
            # pandas Dataframes if an error occurs inside this code block.
            # Convert variable data to pandas (None for complex variables)
            pd_var = copy.deepcopy(self)
            if pd_var.data is not None:
                pd_var.data = utils.convert_to_pandas_df(pd_var.data)  # type: ignore
            # Convert obs data to pandas
            pd_cohort = copy.deepcopy(cohort)
            pd_cohort._obs = utils.convert_to_pandas_df(pd_cohort._obs)  # type: ignore
            # Convert required variables to pandas
            for reqvar in pd_var.required_vars.values():
                if reqvar.data is not None:
                    reqvar.data = utils.convert_to_pandas_df(reqvar.data)  # type: ignore
            data = code(var=pd_var, cohort=pd_cohort)
            # Give type checker a hint
            if not isinstance(data, pd.DataFrame):
                raise TypeError("Variable function did not return a pandas DataFrame")
            self.data = utils.convert_to_polars_df(data)
        return True

    def _add_time_window(self, cohort: Cohort, join_key: str | None = None) -> None:
        """Adds tmin / tmax / join key / primary key from cohort.obs to self.data.

        This might duplicate data if the join_key (default: cohort.primary_key) is not unique in obs.

        Args:
            cohort (Cohort): The cohort instance containing time constraint information.
            join_key (str | None): Column used to join the case bounds onto self.data
                (default: cohort.primary_key if set to None).

        Returns:
            None: Modifies self.data to add tmin / tmax / join key / primary key columns.
        """
        if self.data is None:
            warnings.warn(
                "_add_time_window was called before extraction, skipping.", UserWarning
            )
            return
        obs = cohort._obs.clone()
        # Remove NA tmin/tmax and tmin/tmax in year 9999, etc.
        obs = utils.remove_invalid_time_window(obs, time_window=self.time_window)
        # Return df with tmin and tmax columns
        obs = utils.add_time_window_expr(obs, time_window=self.time_window)
        # Return df with case_tmin and case_tmax columns
        join_key = join_key or cohort.primary_key
        cb = utils.get_cb(
            obs,
            join_key=join_key,
            primary_key=cohort.primary_key,
        )
        self.data = self.data.join(cb, on=join_key)

    def _timefilter(
        self,
        tmin_col: str = "case_tmin",
        tmax_col: str = "case_tmax",
        relation: IntervalRelation = IntervalRelation.ANY_OVERLAP,
    ) -> None:
        """Apply time-based filtering to variable data based on tmin/tmax constraints."""
        if self.data is None:
            warnings.warn(
                "_timefilter was called before extraction, skipping.", UserWarning
            )
            return
        # If these columns do not exist, we can assume that a filter is not intended for this variable
        has_time_columns = (
            tmin_col in self.data.columns and tmax_col in self.data.columns
        )
        if not has_time_columns:
            logger.warning(
                f"Time-filter not applied to {self.var_name}, no {tmin_col} or {tmax_col} in data and not enforced"
            )
            return
        start_col = "recordtime"
        if "recordtime_end" in self.data.columns:
            # Interval records: filter on overlap between [start, end] and the window.
            end_col = "recordtime_end"
            end_cleaned_col = "recordtime_end_clean"
            self.data = (
                self.data
                # Fill missing end times with the start time
                .with_columns(
                    pl.coalesce(end_col, start_col).alias(end_cleaned_col),
                )
                # Ensure valid interval (end >= start)
                .with_columns(
                    pl.max_horizontal(start_col, end_cleaned_col).alias(
                        end_cleaned_col
                    ),
                )
            )
            record = TimeWindow(start_col, end_cleaned_col)
            window = TimeWindow(tmin_col, tmax_col)
            self.data = self.data.filter(
                IntervalOverlap.overlap_expr(
                    record=record,
                    window=window,
                    relation=relation,
                )
            )
            # Remove end_cleaned column
            self.data = self.data.drop(end_cleaned_col)
        else:
            # Point records: simple filter on recordtime.
            # Explicit pl.col() so the bounds are unambiguously columns, not literals.
            self.data = self.data.filter(
                pl.col(start_col).is_between(pl.col(tmin_col), pl.col(tmax_col))
            )
        # Remove tmin / tmax columns
        self.data = self.data.drop(tmin_col, tmax_col)
        logger.info(f"Time-filter selected {len(self.data)} rows, column names:")
        logger.info(utils.pretty_join(self.data.columns, indent="\t"))

    def _apply_cleaning(self) -> bool:
        """Apply cleaning rules to filter out invalid values.

        Returns:
            True if the data was cleaned,
            False otherwise.
            Operations will be applied to Variable.data directly.
        """
        if self.data is None:
            warnings.warn(
                "_apply_cleaning was called before extraction, skipping.", UserWarning
            )
            return False
        if not self.cleaning:
            return False
        for col, bounds in self.cleaning.items():
            # "value" rules may target a column named after the variable itself.
            if (
                col == "value"
                and "value" not in self.data.columns
                and self.var_name in self.data.columns
            ):
                col = self.var_name
            if "low" in bounds:
                self.data = self.data.filter(pl.col(col).ge(bounds["low"]))
            if "high" in bounds:
                self.data = self.data.filter(pl.col(col).le(bounds["high"]))
        return True

    def _add_relative_times(self, cohort: Cohort) -> None:
        """Add relative time columns to dynamic variable data.

        Args:
            cohort (Cohort): The cohort instance containing reference time information.

        Returns:
            None: Modifies self.data to include recordtime_relative and recordtime_end_relative columns.
        """
        if self.data is None:
            warnings.warn(
                "_add_relative_times was called before extraction, skipping.",
                UserWarning,
            )
            return
        # Join with cohort data to get the reference time
        self.data = utils.convert_to_polars_df(self.data)
        self.data = self.data.join(
            cohort._obs.select(cohort.primary_key, cohort.t_min),
            on=cohort.primary_key,
            how="left",
        )

        def relative_column(
            df: pl.DataFrame, time_col: str, reference_col: str
        ) -> pl.DataFrame:
            # Offset from the reference time in whole seconds.
            return df.with_columns(
                pl.col(time_col)
                .sub(pl.col(reference_col))
                .dt.total_seconds()
                .alias(f"{time_col}_relative")
            )

        if "recordtime" in self.data.columns:
            # Calculate relative times from icu_admission in seconds
            self.data = self.data.pipe(
                relative_column, time_col="recordtime", reference_col=cohort.t_min
            )
        if (
            "recordtime_end" in self.data.columns
            and not self.data["recordtime_end"].is_null().all()
        ):
            self.data = self.data.pipe(
                relative_column, time_col="recordtime_end", reference_col=cohort.t_min
            )
        # Drop the reference time column
        self.data = self.data.drop(cohort.t_min)

    def _unify_and_order_columns(self, primary_key: str) -> None:
        """Order columns according to predefined COL_ORDER and add missing columns with null values.

        Args:
            primary_key (str): The primary key column name.

        Returns:
            None: Modifies self.data to have standardized column order.
        """
        if self.data is None:
            warnings.warn(
                "_unify_and_order_columns was called before extraction, skipping.",
                UserWarning,
            )
            return
        assert (
            primary_key in PRIMARY_KEYS
        ), f"Primary key {primary_key} must be one of {PRIMARY_KEYS}."
        assert (
            primary_key in self.data.columns
        ), f"Primary key {primary_key} not in data columns."
        # Add missing dynamic columns (no keys) with null values.
        # Must be a list: a generator here would always be truthy.
        missing_cols = [col for col in DYN_COLUMNS if col not in self.data.columns]
        if missing_cols:
            self.data = self.data.with_columns(
                pl.lit(None).alias(col) for col in missing_cols
            )
        # Order columns according to COL_ORDER, then add remaining columns.
        # Iterate self.data.columns (not the set) so the result is deterministic.
        data_cols = set(self.data.columns)
        ordered_cols = [col for col in COL_ORDER if col in data_cols]
        remaining_cols = [col for col in self.data.columns if col not in COL_ORDER]
        if len(remaining_cols) > 0:
            warnings.warn(
                "🚨 DEPRECATION ALERT 🚨\n"
                "Unknown columns found.\n"
                "- Please update the variable definition to move these to attributes.\n"
                "- Unknown columns are no longer supported and might be deprecated soon!\n"
                "- Columns: " + ", ".join(remaining_cols),
                UserWarning,
            )
        self.data = self.data.select(ordered_cols + remaining_cols)
        # Uncomment after user warning
        # self.data = self.data.select(ordered_cols)

    # System methods
    def __repr__(self) -> str:
        return self.__str__()

    def __str__(self) -> str:
        return textwrap.dedent(
            f"""\
            Variable: {self.var_name} ({self.__class__.__name__}, {'dynamic' if self.dynamic else 'static'})
            ├── time_window: {self.time_window}
            └── data: {'Not extracted' if self.data is None else self.data.shape}
            """
        ).strip()

    def __getstate__(self) -> dict[str, Any]:
        return self.__dict__

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
class NativeVariable(Variable):
    """Extended Variable class for native variables from data sources.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable | None): List of variables or dict of variables with tmin/tmax required to calculate the variable (default: None).
        tmin (TimeBoundColumn | None): The tmin argument. Can either be a string (column name) or a tuple of (column name, timedelta) (default: None).
        tmax (TimeBoundColumn | None): The tmax argument. Can either be a string (column name) or a tuple of (column name, timedelta) (default: None).
        py (Callable): The function to call to calculate the variable (default: None).
        py_ready_polars (bool): True if the variable code can accept polars dataframes as input and return a polars dataframe (default: False).
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable (default: None).
        allow_caching (bool): Whether to allow caching of this variable (default: True).
    """

    def __init__(
        self,
        # Parent class parameters
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable | None = None,
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        # This class parameters
        allow_caching: bool = True,
    ) -> None:
        super().__init__(
            var_name,
            dynamic=dynamic,
            # `None` sentinel instead of a mutable `[]` default argument.
            requires=requires if requires is not None else {},
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        # Whether extraction results for this variable may be cached.
        self.allow_caching = allow_caching

    def build_attributes(self, data: PolarsFrame) -> PolarsFrame:
        """Combines columns prefixed by attributes_ into a struct column called attributes"""
        if not self.dynamic:
            logger.debug("Skipping attributes for non-dynamic variables")
            return data
        prefix = "attributes_"
        attr_cols = [
            col for col in data.collect_schema().names() if col.startswith(prefix)
        ]
        if attr_cols:
            # Strip the prefix to get the struct field names.
            struct_expr = {col[len(prefix):]: pl.col(col) for col in attr_cols}
            # Mypy does not recognise the kwargs correctly as an overload
            data = data.with_columns(pl.struct(**struct_expr).alias("attributes"))  # type: ignore[call-overload]
            data = data.drop(attr_cols)
        return data