# Source code for corr_vars.core.variable

from __future__ import annotations

import copy
from datetime import datetime, date
from enum import Flag, auto
import textwrap
import warnings

import pandas as pd
import polars as pl

from corr_vars import logger
import corr_vars.utils as utils
from corr_vars.sources.var_loader import load_variable

from typing import Any, TypeAlias, TYPE_CHECKING
from corr_vars.definitions.constants import DYN_COLUMNS, COL_ORDER, PRIMARY_KEYS
from corr_vars.definitions.typing import (
    ExtractedVariable,
    PolarsFrame,
    is_time_anchor_column,
)

if TYPE_CHECKING:
    from corr_vars.core import Cohort
    from corr_vars.definitions.typing import (
        TimeAnchorColumn,
        RequirementsIterable,
        RequirementsDict,
        CleaningDict,
        VariableCallable,
    )

DateLiteral: TypeAlias = datetime | date | int


class TimeAnchor:
    """A single time bound for a :class:`TimeWindow`.

    A bound is one of:

    * a ``datetime``/``date`` literal,
    * a column name (``str``), or
    * a ``(column, offset)`` tuple, where ``offset`` is a polars duration
      string such as ``"-2h"`` or ``"+6h"``.

    The resulting polars expression is always aliased to ``"time"``.
    """

    column: str | None  # source column name; None for datetime/date literals
    delta: str | None  # polars offset string (leading "+" stripped); None if absent
    expr: pl.Expr  # expression evaluating to the anchor time, aliased "time"

    def __init__(self, bound: TimeAnchorColumn | DateLiteral) -> None:
        # NOTE(review): DateLiteral also admits `int`, but integers are rejected
        # here — confirm whether integer timestamps were ever meant to work.
        if not is_time_anchor_column(bound) and not isinstance(bound, (datetime, date)):
            raise TypeError(
                "Expected a column name (str), a (column, offset) tuple of two "
                "strings, or a datetime/date literal."
            )

        if isinstance(bound, (datetime, date)):
            # Fixed point in time.
            self.column = None
            self.delta = None
            self.expr = pl.lit(bound)
        elif isinstance(bound, str):
            # Plain column reference.
            self.column = bound
            self.delta = None
            self.expr = pl.col(self.column)
        else:
            # Column shifted by an offset. Normalize "+2h" -> "2h"; polars
            # offset strings carry the sign themselves.
            self.column, delta = bound
            self.delta = delta.removeprefix("+")
            self.expr = pl.col(self.column).dt.offset_by(self.delta)

        self.expr = self.expr.alias("time")

    def __repr__(self) -> str:
        if self.column is None or self.delta is None:
            # Literal anchors render their value; plain columns their name.
            return self.column or (
                f"Literal: {pl.select(self.expr).item()}"
                if self.expr.meta.is_literal()
                else ", ".join(self.expr.meta.root_names())
            )
        op = "-" if self.delta[0] == "-" else "+"
        rest = self.delta[1:] if self.delta[0] == "-" else self.delta
        return f"{self.column} {op} {rest}"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TimeAnchor):
            return False
        return (
            self.column == other.column
            and self.delta == other.delta
            and self.expr.meta.eq(other.expr)
        )

    @property
    def is_relative(self) -> bool:
        """True when the anchor is a column shifted by an offset."""
        return self.column is not None and self.delta is not None


class TimeWindow:
    """A ``[tmin, tmax]`` window built from two :class:`TimeAnchor` bounds.

    Tracks whether each bound was explicitly supplied by the user and whether
    the window is frozen (bounds immutable).
    """

    _tmin: TimeAnchor
    _tmax: TimeAnchor
    user_specified_tmin: bool
    user_specified_tmax: bool
    frozen: bool

    def __init__(
        self,
        tmin: TimeAnchorColumn | DateLiteral | None = None,
        tmax: TimeAnchorColumn | DateLiteral | None = None,
        user_specified: bool | tuple[bool, bool] = False,
        frozen: bool = False,
    ) -> None:
        # Missing bounds default to the widest possible window.
        self._tmin = TimeAnchor(tmin or datetime.min)
        self._tmax = TimeAnchor(tmax or datetime.max)
        if isinstance(user_specified, tuple):
            self.user_specified_tmin, self.user_specified_tmax = user_specified
        else:
            self.user_specified_tmin = user_specified
            self.user_specified_tmax = user_specified
        self.frozen = frozen

    def __repr__(self) -> str:
        return f"TimeWindow(tmin={self._tmin}, tmax={self._tmax})"

    def __eq__(self, other: object) -> bool:
        # Note: user_specified_* and frozen are deliberately not compared.
        if not isinstance(other, TimeWindow):
            return False
        return self._tmin == other._tmin and self._tmax == other._tmax

    def freeze(self) -> None:
        """Make the bounds immutable; subsequent setter calls raise ValueError."""
        self.frozen = True

    @property
    def user_specified(self) -> bool:
        """True if either bound was explicitly supplied by the user."""
        return self.user_specified_tmin or self.user_specified_tmax

    @property
    def tmin(self) -> TimeAnchor:
        return self._tmin

    @property
    def tmax(self) -> TimeAnchor:
        return self._tmax

    @tmin.setter
    def tmin(self, value: TimeAnchor) -> None:
        if self.frozen:
            raise ValueError("Time bounds are frozen")
        self._tmin = value

    @tmax.setter
    def tmax(self, value: TimeAnchor) -> None:
        if self.frozen:
            raise ValueError("Time bounds are frozen")
        self._tmax = value

    # Fixed: these were properties with an `alias` parameter that could never
    # be supplied (property getters take no arguments), so the defaults were
    # always used. The dead parameters are removed and the aliases hardcoded.
    @property
    def tmin_expr(self) -> pl.Expr:
        """Lower bound as a polars expression aliased ``"tmin"``."""
        return self._tmin.expr.alias("tmin")

    @property
    def tmax_expr(self) -> pl.Expr:
        """Upper bound as a polars expression aliased ``"tmax"``."""
        return self._tmax.expr.alias("tmax")

    @property
    def duration_expr(self) -> pl.Expr:
        """Window length (tmax - tmin) aliased ``"duration"``."""
        return (self.tmax_expr - self.tmin_expr).alias("duration")

    @property
    def tmid_expr(self) -> pl.Expr:
        """Window midpoint aliased ``"tmid"``.

        Fixed: previously computed ``tmin + duration`` (which equals tmax, not
        the midpoint) and aliased the result ``"tmin"``.
        """
        return (self.tmin_expr + self.duration_expr / 2).alias("tmid")

    @property
    def is_relative(self) -> bool:
        """True if either bound is a column shifted by an offset."""
        return self.tmin.is_relative or self.tmax.is_relative

    @property
    def any_timebound_in_year_9999_expr(self) -> pl.Expr:
        """True where either bound falls in the sentinel year 9999."""
        return pl.any_horizontal(
            self.tmin_expr.dt.year().eq(pl.lit(9999)).fill_null(False),
            self.tmax_expr.dt.year().eq(pl.lit(9999)).fill_null(False),
        )

    @property
    def any_timebound_null_expr(self) -> pl.Expr:
        """True where either bound is null."""
        return pl.any_horizontal(
            self.tmin_expr.is_null(),
            self.tmax_expr.is_null(),
        )

    @property
    def tmin_after_tmax_expr(self) -> pl.Expr:
        """True where the window is inverted (tmin > tmax)."""
        return self.tmin_expr > self.tmax_expr

    @property
    def any_timebound_invalid_expr(self) -> pl.Expr:
        """True where any validity check (null, year 9999, inverted) fails."""
        return pl.any_horizontal(
            self.any_timebound_null_expr,
            self.any_timebound_in_year_9999_expr,
            self.tmin_after_tmax_expr,
        )


class IntervalRelation(Flag):
    """Classification of how a record interval sits relative to a time window."""

    CONTAINED = auto()  # entirely within [tmin, tmax]
    OVERLAPS_LEFT = auto()  # begins before tmin, ends inside the window
    OVERLAPS_RIGHT = auto()  # begins inside the window, ends after tmax
    COVERS_WINDOW = auto()  # stretches across both boundaries

    # Frequently used unions
    ANY_LEFT_OVERLAP = CONTAINED | OVERLAPS_LEFT | COVERS_WINDOW
    ANY_RIGHT_OVERLAP = CONTAINED | OVERLAPS_RIGHT | COVERS_WINDOW
    ANY_OVERLAP = CONTAINED | OVERLAPS_LEFT | OVERLAPS_RIGHT | COVERS_WINDOW


class IntervalOverlap:
    """Builders for polars expressions describing record/window interval relations."""

    @staticmethod
    def relation_expr(
        record: TimeWindow,
        window: TimeWindow,
    ) -> dict[IntervalRelation, pl.Expr]:
        """Return one boolean expression per basic IntervalRelation flag."""
        rec_start = record.tmin_expr
        rec_end = record.tmax_expr
        win_lo = window.tmin_expr
        win_hi = window.tmax_expr

        contained = (rec_start >= win_lo) & (rec_end <= win_hi)
        left = (rec_start < win_lo) & (rec_end >= win_lo) & (rec_end <= win_hi)
        right = (rec_start >= win_lo) & (rec_start <= win_hi) & (rec_end > win_hi)
        covers = (rec_start < win_lo) & (rec_end > win_hi)

        return {
            IntervalRelation.CONTAINED: contained,
            IntervalRelation.OVERLAPS_LEFT: left,
            IntervalRelation.OVERLAPS_RIGHT: right,
            IntervalRelation.COVERS_WINDOW: covers,
        }

    @staticmethod
    def overlap_expr(
        record: TimeWindow,
        window: TimeWindow,
        relation: IntervalRelation,
    ) -> pl.Expr:
        """OR together the conditions of every basic flag contained in `relation`."""
        expr_map = IntervalOverlap.relation_expr(record, window)
        selected = [cond for flag, cond in expr_map.items() if relation & flag]

        if not selected:
            raise ValueError(f"Invalid relation: {relation}")

        combined = selected[0]
        for cond in selected[1:]:
            combined = combined | cond
        return combined

    @staticmethod
    def classify_expr(
        record: TimeWindow,
        window: TimeWindow,
        *,
        default: IntervalRelation | None = None,
    ) -> pl.Expr:
        """Build a when/then chain mapping each row to its relation's flag value.

        Rows matching no relation get `default.value` (or null when `default`
        is None). CONTAINED is checked first, mirroring the declaration order.
        """
        exprs = IntervalOverlap.relation_expr(record, window)
        first, *rest = (
            IntervalRelation.CONTAINED,
            IntervalRelation.OVERLAPS_LEFT,
            IntervalRelation.OVERLAPS_RIGHT,
            IntervalRelation.COVERS_WINDOW,
        )

        chain = pl.when(exprs[first]).then(pl.lit(first.value))
        for rel in rest:
            chain = chain.when(exprs[rel]).then(pl.lit(rel.value))

        fallback = pl.lit(None) if default is None else pl.lit(default.value)
        return chain.otherwise(fallback)


class Variable:
    """Base class for all variables.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable): List of variables or dict of variables
            with tmin/tmax required to calculate the variable (default: []).
        tmin (TimeAnchorColumn | None): The tmin argument. Can either be a string
            (column name) or a tuple of (column name, timedelta).
        tmax (TimeAnchorColumn | None): The tmax argument. Can either be a string
            (column name) or a tuple of (column name, timedelta).
        py (Callable): The function to call to calculate the variable.
        py_ready_polars (bool): True if the variable code can accept polars
            dataframes as input and return a polars dataframe.
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable.

    Note:
        tmin and tmax can be None when you create a Variable object, but must be
        set before extraction. If you add the variable via cohort.add_variable(),
        it will be automatically set to the cohort's tmin and tmax.

        This base class should not be used directly; use one of the subclasses
        instead. These are specified in the sources submodule.

    Examples:
        Basic cleaning configuration:

        >>> cleaning = {
        ...     "value": {
        ...         "low": 10,
        ...         "high": 80
        ...     }
        ... }

        Time constraints with relative offsets:

        >>> # Extract data from 2 hours before to 6 hours after ICU admission
        >>> tmin = ("icu_admission", "-2h")
        >>> tmax = ("icu_admission", "+6h")

        Variable with dependencies:

        >>> # This variable requires other variables to be loaded first
        >>> requires = ["blood_pressure_sys", "blood_pressure_dia"]

        >>> # This variable requires other variables with fixed tmin/tmax to be loaded first
        >>> requires = {
        ...     "blood_pressure_sys_hospital": {
        ...         "template": "blood_pressure_sys",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     },
        ...     "blood_pressure_dia_hospital": {
        ...         "template": "blood_pressure_dia",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     }
        ... }
    """

    var_name: str  # the variable name
    dynamic: bool  # True for time-series variables
    requirements: dict[str, RequirementsDict]  # normalized dependency configs
    required_vars: dict[str, ExtractedVariable]  # dependencies loaded at extract time
    data: pl.DataFrame | None  # extracted data; None before extraction
    time_window: TimeWindow  # tmin/tmax window applied during extraction
    py: VariableCallable | None  # optional post-processing function
    py_ready_polars: bool  # True if `py` accepts/returns polars frames
    cleaning: CleaningDict | None  # value-range cleaning rules

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        tmin: TimeAnchorColumn | None,
        tmax: TimeAnchorColumn | None,
        requires: RequirementsIterable = [],
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> None:
        # NOTE(review): mutable default `requires=[]` — harmless because it is
        # only read here, but an immutable default would be safer.
        # Metadata
        self.var_name = var_name
        self.dynamic = dynamic  # True if the variable is dynamic (time-series)

        # Tmin / Tmax
        self.time_window = TimeWindow(tmin=tmin, tmax=tmax, user_specified=True)

        # Requirements: a plain iterable is normalized into the dict-of-configs
        # form with open (None) tmin/tmax.
        self.requirements = (
            requires
            if isinstance(requires, dict)
            else {r: {"template": r, "tmin": None, "tmax": None} for r in requires}
        )
        self.required_vars = {}

        # Variable function
        self.py = py
        self.py_ready_polars = py_ready_polars

        # Cleaning
        self.cleaning = cleaning

        # Data
        self.data = None

    @classmethod
    def from_time_window(
        cls,
        var_name: str,
        dynamic: bool,
        time_window: TimeWindow,
        requires: RequirementsIterable = [],
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> Variable:
        """Alternate constructor: build a Variable around an existing TimeWindow.

        The window replaces the one `__init__` creates from tmin/tmax=None.
        """
        var = cls(
            var_name=var_name,
            dynamic=dynamic,
            tmin=None,
            tmax=None,
            requires=requires,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        var.time_window = time_window
        return var

    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """
        Extracts data from the datasource. Usually follows this pattern.

        ```
        # Load & Extract required variables
        self._get_required_vars(cohort)

        # This should change self.data either by returning or side effect
        self.data = self._custom_extraction(cohort)

        # Or convert_polars=False for pandas functions
        self._call_var_function(cohort)

        if self.dynamic:
            # Expects case_tmin, case_tmax for each primary key
            self._apply_time_window(cohort)
            self._timefilter()

        self._apply_cleaning()
        self._add_relative_times(cohort)
        self._unify_and_order_columns(cohort.primary_key)
        ```
        """
        raise NotImplementedError

    def _get_required_vars(self, cohort: Cohort) -> None:
        """Load required variables for this variable.

        Args:
            cohort (Cohort): The cohort instance to load variables from.

        Returns:
            None: Sets self.required_vars with loaded Variable objects.
        """
        if len(self.requirements) == 0:
            logger.debug(f"No dependencies loaded for '{self.var_name}'.")
            return

        # Print dependency tree before extraction
        logger.info(f"Loading dependencies for '{self.var_name}':")
        utils.log_collection(logger=logger, collection=self.requirements)

        for save_as, config in self.requirements.items():
            var_name = config["template"]
            if save_as in cohort.constant_vars:
                # Constant variables are read straight from cohort._obs, so any
                # per-dependency tmin/tmax would be silently ignored — warn.
                # NOTE(review): the condition also requires *this* variable's
                # window to be user-specified; confirm whether checking
                # config["tmin"]/["tmax"] alone was intended.
                if (config["tmin"] and self.time_window.user_specified_tmin) or (
                    config["tmax"] and self.time_window.user_specified_tmax
                ):
                    warnings.warn(
                        f"tmin/tmax are ignored for the constant variable {var_name}. "
                        "Please remove tmin/tmax.",
                        UserWarning,
                    )
                self.required_vars[save_as] = ExtractedVariable(
                    var_name=var_name,
                    dynamic=False,
                    data=cohort._obs.clone().select(cohort.primary_key, var_name),
                )
            else:
                # Non-constant dependencies are loaded and extracted with their
                # own (possibly open) time window.
                var = load_variable(
                    var_name=var_name,
                    cohort=cohort,
                    time_window=TimeWindow(config["tmin"], config["tmax"]),
                )
                data = var.extract(cohort)
                self.required_vars[save_as] = ExtractedVariable(
                    var_name=var.var_name, dynamic=var.dynamic, data=data
                )

    def _call_var_function(self, cohort: Cohort) -> bool:
        """
        Call the variable function if it exists.

        Args:
            cohort (Cohort)

        Returns:
            True if the function was called, False otherwise.
            Operations will be applied to Variable.data directly.
        """
        if self.py is None:
            logger.debug(f"Variable function not called for {self.var_name}")
            return False

        code = self.py
        logger.debug(f"Calling Variable function {self.var_name}")

        if self.py_ready_polars:
            # Cohort is deep-copied so the user function cannot mutate it.
            data = code(var=self, cohort=copy.deepcopy(cohort))
            # Give type checker a hint
            if not isinstance(data, pl.DataFrame):
                raise TypeError("Variable function did not return a polars DataFrame")
            self.data = data
        else:
            # Legacy pandas code path
            logger.warning(
                f"Converting {self.var_name} to pandas. This is slow, consider rewriting your variable function to accept polars dataframes and then set py_ready_polars=True"
            )

            # Ignore all type errors to keep the rest of the code clean.
            # Only operate on copies of Cohort and Variable to avoid
            # pandas Dataframes if an error occurs inside this code block.

            # Convert variable data to pandas (None for complex variables)
            pd_var = copy.deepcopy(self)
            if pd_var.data is not None:
                pd_var.data = utils.convert_to_pandas_df(pd_var.data)  # type: ignore

            # Convert obs data to pandas
            pd_cohort = copy.deepcopy(cohort)
            pd_cohort._obs = utils.convert_to_pandas_df(pd_cohort._obs)  # type: ignore

            # Convert required variables to pandas
            for reqvar in pd_var.required_vars.values():
                if reqvar.data is not None:
                    reqvar.data = utils.convert_to_pandas_df(reqvar.data)  # type: ignore

            data = code(var=pd_var, cohort=pd_cohort)
            # Give type checker a hint
            if not isinstance(data, pd.DataFrame):
                raise TypeError("Variable function did not return a pandas DataFrame")
            self.data = utils.convert_to_polars_df(data)

        return True

    def _add_time_window(self, cohort: Cohort, join_key: str | None = None) -> None:
        """Adds tmin / tmax / join key / primary key from cohort.obs to self.data.

        This might duplicate data if the join_key (default: cohort.primary_key)
        is not unique in obs.

        Args:
            cohort (Cohort): The cohort instance containing time constraint information.
            join_key (str | None): Column to join cohort obs on
                (default: cohort.primary_key if set to None).

        Returns:
            None: Modifies self.data to add tmin / tmax / join key / primary key columns.
        """
        if self.data is None:
            warnings.warn(
                "_add_time_window was called before extraction, skipping.", UserWarning
            )
            return

        obs = cohort._obs.clone()

        # Remove NA tmin/tmax and tmin/tmax in year 9999, etc.
        obs = utils.remove_invalid_time_window(obs, time_window=self.time_window)

        # Return df with tmin and tmax columns
        obs = utils.add_time_window_expr(obs, time_window=self.time_window)

        # Return df with case_tmin and case_tmax columns
        join_key = join_key or cohort.primary_key
        cb = utils.get_cb(
            obs,
            join_key=join_key,
            primary_key=cohort.primary_key,
        )

        self.data = self.data.join(cb, on=join_key)

    def _timefilter(
        self,
        tmin_col: str = "case_tmin",
        tmax_col: str = "case_tmax",
        relation: IntervalRelation = IntervalRelation.ANY_OVERLAP,
    ) -> None:
        """Apply time-based filtering to variable data based on tmin/tmax constraints."""
        if self.data is None:
            warnings.warn(
                "_timefilter was called before extraction, skipping.", UserWarning
            )
            return

        # If these columns do not exist, we can assume that a filter is not intended for this variable
        has_time_columns = (
            tmin_col in self.data.columns and tmax_col in self.data.columns
        )
        if not has_time_columns:
            logger.warning(
                f"Time-filter not applied to {self.var_name}, no {tmin_col} or {tmax_col} in data and not enforced"
            )
            return

        start_col = "recordtime"
        if "recordtime_end" in self.data.columns:
            # Interval records: clean up the end column, then filter by the
            # requested interval relation.
            end_col = "recordtime_end"
            end_cleaned_col = "recordtime_end_clean"
            self.data = (
                self.data
                # Fill missings
                .with_columns(
                    pl.coalesce(end_col, start_col).alias(end_cleaned_col),
                )
                # Ensure valid interval (end >= start)
                .with_columns(
                    pl.max_horizontal(start_col, end_cleaned_col).alias(
                        end_cleaned_col
                    ),
                )
            )

            record = TimeWindow(start_col, end_cleaned_col)
            window = TimeWindow(tmin_col, tmax_col)

            self.data = self.data.filter(
                IntervalOverlap.overlap_expr(
                    record=record,
                    window=window,
                    relation=relation,
                )
            )

            # Remove end_cleaned column
            self.data = self.data.drop(end_cleaned_col)
        else:
            # Simple filter on recordtime.
            # NOTE(review): the string bounds rely on polars parsing them as
            # column references (case_tmin / case_tmax) — confirm.
            self.data = self.data.filter(
                pl.col(start_col).is_between(tmin_col, tmax_col)
            )

        # Remove tmin / tmax columns
        self.data = self.data.drop(tmin_col, tmax_col)
        logger.info(f"Time-filter selected {len(self.data)} rows, column names:")
        logger.info(utils.pretty_join(self.data.columns, indent="\t"))

    def _apply_cleaning(self) -> bool:
        """Apply cleaning rules to filter out invalid values.

        Returns:
            True if the data was cleaned, False otherwise.
            Operations will be applied to Variable.data directly.
        """
        if self.data is None:
            warnings.warn(
                "_apply_cleaning was called before extraction, skipping.", UserWarning
            )
            return False

        if not self.cleaning:
            return False

        for col, bounds in self.cleaning.items():
            # The generic "value" rule falls back to a column named after the
            # variable itself when no "value" column exists.
            if (
                col == "value"
                and "value" not in self.data.columns
                and self.var_name in self.data.columns
            ):
                col = self.var_name
            if "low" in bounds:
                self.data = self.data.filter(pl.col(col).ge(bounds["low"]))
            if "high" in bounds:
                self.data = self.data.filter(pl.col(col).le(bounds["high"]))

        return True

    def _add_relative_times(self, cohort: Cohort) -> None:
        """Add relative time columns to dynamic variable data.

        Args:
            cohort (Cohort): The cohort instance containing reference time information.

        Returns:
            None: Modifies self.data to include recordtime_relative and
            recordtime_end_relative columns.
        """
        if self.data is None:
            warnings.warn(
                "_add_relative_times was called before extraction, skipping.",
                UserWarning,
            )
            return

        # Join with cohort data to get the reference time
        self.data = utils.convert_to_polars_df(self.data)
        self.data = self.data.join(
            cohort._obs.select(cohort.primary_key, cohort.t_min),
            on=cohort.primary_key,
            how="left",
        )

        def relative_column(
            df: pl.DataFrame, time_col: str, reference_col: str
        ) -> pl.DataFrame:
            # Seconds elapsed between reference_col and time_col.
            return df.with_columns(
                pl.col(time_col)
                .sub(pl.col(reference_col))
                .dt.total_seconds()
                .alias(f"{time_col}_relative")
            )

        if "recordtime" in self.data.columns:
            # Calculate relative times from icu_admission in seconds
            self.data = self.data.pipe(
                relative_column, time_col="recordtime", reference_col=cohort.t_min
            )

        # Only compute end-relative times when at least one end value exists.
        if (
            "recordtime_end" in self.data.columns
            and not self.data["recordtime_end"].is_null().all()
        ):
            self.data = self.data.pipe(
                relative_column, time_col="recordtime_end", reference_col=cohort.t_min
            )

        # Drop the reference time column
        self.data = self.data.drop(cohort.t_min)

    def _unify_and_order_columns(self, primary_key: str) -> None:
        """Order columns according to predefined COL_ORDER and add missing columns with null values.

        Args:
            primary_key (str): The primary key column name.

        Returns:
            None: Modifies self.data to have standardized column order.
        """
        if self.data is None:
            warnings.warn(
                "_unify_and_order_columns was called before extraction, skipping.",
                UserWarning,
            )
            return

        assert (
            primary_key in PRIMARY_KEYS
        ), f"Primary key {primary_key} must be one of {PRIMARY_KEYS}."
        assert (
            primary_key in self.data.columns
        ), f"Primary key {primary_key} not in data columns."

        # Add missing dynamic columns (no keys) with null values
        # NOTE(review): `missing_cols` is a generator, so this `if` is always
        # truthy even when nothing is missing — harmless (with_columns of an
        # empty generator is a no-op) but the check does not do what it looks like.
        missing_cols = (col for col in DYN_COLUMNS if col not in self.data.columns)
        if missing_cols:
            self.data = self.data.with_columns(
                pl.lit(None).alias(col) for col in missing_cols
            )

        # Order columns according to COL_ORDER, then add remaining columns
        # NOTE(review): `remaining_cols` iterates a set, so the trailing column
        # order is not deterministic across runs — confirm this is acceptable.
        data_cols = set(self.data.columns)
        ordered_cols = [col for col in COL_ORDER if col in data_cols]
        remaining_cols = [col for col in data_cols if col not in COL_ORDER]
        if len(remaining_cols) > 0:
            warnings.warn(
                "🚨 DEPRECATION ALERT 🚨\n"
                "Unknown columns found.\n"
                "- Please update the variable definition to move these to attributes.\n"
                "- Unknown columns are no longer supported and might be deprecated soon!\n"
                "- Columns: " + ", ".join(remaining_cols),
                UserWarning,
            )
        self.data = self.data.select(ordered_cols + remaining_cols)
        # Uncomment after user warning
        # self.data = self.data.select(ordered_cols)

    # System methods
    def __repr__(self) -> str:
        return self.__str__()

    def __str__(self) -> str:
        return textwrap.dedent(
            f"""\
            Variable: {self.var_name} ({self.__class__.__name__}, {'dynamic' if self.dynamic else 'static'})
            ├── time_window: {self.time_window}
            └── data: {'Not extracted' if self.data is None else self.data.shape}
            """
        ).strip()

    def __getstate__(self) -> dict[str, Any]:
        # Default-like pickling; defined explicitly so subclasses pickle uniformly.
        return self.__dict__

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
class NativeVariable(Variable):
    """Extended Variable class for native variables from data sources.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable): List of variables or dict of variables
            with tmin/tmax required to calculate the variable (default: []).
        tmin (TimeAnchorColumn | None): The tmin argument. Can either be a string
            (column name) or a tuple of (column name, timedelta) (default: None).
        tmax (TimeAnchorColumn | None): The tmax argument. Can either be a string
            (column name) or a tuple of (column name, timedelta) (default: None).
        py (Callable): The function to call to calculate the variable (default: None).
        py_ready_polars (bool): True if the variable code can accept polars
            dataframes as input and return a polars dataframe (default: False).
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable
            (default: None).
        allow_caching (bool): Whether to allow caching of this variable
            (default: True).
    """

    def __init__(
        self,
        # Parent class parameters
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable = [],
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        # This class parameters
        allow_caching: bool = True,
    ) -> None:
        # Everything except allow_caching is handled by the parent class.
        super().__init__(
            var_name=var_name,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        self.allow_caching = allow_caching
[docs] def build_attributes(self, data: PolarsFrame) -> PolarsFrame: """Combines columns prefixed by attributes_ into a struct column called attributes""" if not self.dynamic: logger.debug("Skipping attribtues for non-dynamic variables") return data attr_cols = [ col for col in data.collect_schema().names() if col.startswith("attributes_") ] if attr_cols: struct_expr = {col[len("attributes_") :]: pl.col(col) for col in attr_cols} # Mypy does not recognise the kwargs correctly as an overload data = data.with_columns(pl.struct(**struct_expr).alias("attributes")) # type: ignore[call-overload] data = data.drop(attr_cols) return data
    # NOTE(review): this commented-out helper predates the TimeWindow refactor —
    # it references `self.tmin` / `self.tmax`, which no longer exist on Variable
    # (the window now lives in `self.time_window`). Update before reviving.
    # def on_admission(self, select: str = "!first value") -> NativeStatic:
    #     """Create a new NativeStatic variable based on the current variable that extracts the value on admission.

    #     Args:
    #         select: Select clause specifying aggregation function and columns. Defaults to ``!first value``.

    #     Returns:
    #         NativeStatic

    #     Examples:
    #         >>> # Return the first value
    #         >>> var_adm = variable.on_admission()
    #         >>> cohort.add_variable(var_adm)

    #         >>> # Be more specific with your selection
    #         >>> var_adm = variable.on_admission("!closest(hospital_admission,0,2h) value")
    #         >>> cohort.add_variable(var_adm)
    #     """
    #     if not self.dynamic:
    #         raise ValueError("on_admission is only supported for dynamic variables")

    #     return NativeStatic(
    #         var_name=f"{self.var_name}_adm",
    #         select=select,
    #         base_var=self.var_name,
    #         tmin=self.tmin,
    #         tmax=self.tmax,
    #     )