# Source code for corr_vars.core.variable

from __future__ import annotations

import copy
import textwrap
import warnings

import pandas as pd
import polars as pl

from corr_vars import logger
import corr_vars.utils as utils
from corr_vars.sources.var_loader import load_variable

from corr_vars.definitions.constants import DYN_COLUMNS, COL_ORDER, PRIMARY_KEYS
from corr_vars.definitions.typing import (
    ExtractedVariable,
    PolarsFrame,
)
from corr_vars.utils.time import (
    TimeAnchor,
    TimeWindow,
    WindowOverlap,
    WindowRelation,
    compatible_with_time_window,
)
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    from corr_vars.core import Cohort
    from corr_vars.definitions.typing import (
        TimeAnchorColumn,
        RequirementsIterable,
        RequirementsDict,
        CleaningDict,
        VariableCallable,
    )


class Variable:
    """Base class for all variables.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable): List of variables or dict of variables
            with tmin/tmax required to calculate the variable (default: ()).
        tmin (TimeBoundColumn | None): The tmin argument. Can either be a string
            (column name) or a tuple of (column name, timedelta).
        tmax (TimeBoundColumn | None): The tmax argument. Can either be a string
            (column name) or a tuple of (column name, timedelta).
        py (Callable): The function to call to calculate the variable.
        py_ready_polars (bool): True if the variable code can accept polars
            dataframes as input and return a polars dataframe.
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable.

    Note:
        tmin and tmax can be None when you create a Variable object, but must be
        set before extraction. If you add the variable via cohort.add_variable(),
        it will be automatically set to the cohort's tmin and tmax.

        This base class should not be used directly; use one of the subclasses
        instead. These are specified in the sources submodule.

    Examples:
        Basic cleaning configuration:

        >>> cleaning = {
        ...     "value": {
        ...         "low": 10,
        ...         "high": 80
        ...     }
        ... }

        Time constraints with relative offsets:

        >>> # Extract data from 2 hours before to 6 hours after ICU admission
        >>> tmin = ("icu_admission", "-2h")
        >>> tmax = ("icu_admission", "+6h")

        Variable with dependencies:

        >>> # This variable requires other variables to be loaded first
        >>> requires = ["blood_pressure_sys", "blood_pressure_dia"]

        >>> # This variable requires other variables with fixed tmin/tmax to be loaded first
        >>> requires = {
        ...     "blood_pressure_sys_hospital": {
        ...         "template": "blood_pressure_sys",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     },
        ...     "blood_pressure_dia_hospital": {
        ...         "template": "blood_pressure_dia",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     }
        ... }
    """

    # Identity / metadata
    var_name: str
    dynamic: bool
    # Dependency configuration (normalized dict form) and the loaded dependencies
    requirements: dict[str, RequirementsDict]
    required_vars: dict[str, ExtractedVariable]
    # Extracted data; None until extract() has run
    data: pl.DataFrame | None
    # tmin/tmax constraints used during extraction
    time_window: TimeWindow
    # Optional post-processing function and whether it accepts polars frames
    py: VariableCallable | None
    py_ready_polars: bool
    # Optional value-range cleaning rules
    cleaning: CleaningDict | None

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        tmin: TimeAnchorColumn | None,
        tmax: TimeAnchorColumn | None,
        # Immutable default instead of the mutable-default-argument antipattern.
        requires: RequirementsIterable = (),
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> None:
        # Metadata
        self.var_name = var_name
        self.dynamic = dynamic  # True if the variable is dynamic (time-series)

        # Tmin / Tmax
        self.time_window = TimeWindow(tmin=tmin, tmax=tmax, user_specified=True)

        # Requirements: a plain iterable of names is normalized into the dict
        # form {save_as: {"template": ..., "tmin": ..., "tmax": ...}}.
        self.requirements = (
            requires
            if isinstance(requires, dict)
            else {r: {"template": r, "tmin": None, "tmax": None} for r in requires}
        )
        self.required_vars = {}

        # Variable function
        self.py = py
        self.py_ready_polars = py_ready_polars

        # Cleaning
        self.cleaning = cleaning

        # Data
        self.data = None
[docs] @classmethod def from_time_window(cls, time_window: TimeWindow, *args, **kwargs) -> Variable: var = cls(*args, **kwargs) var.time_window = time_window return var
    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """Extract data from the datasource.

        Subclasses must override this method. A typical implementation follows
        this pattern:

        ```
        # Load & Extract required variables
        self._get_required_vars(cohort)

        # This should change self.data either by returning or side effect
        self.data = self._custom_extraction(cohort)

        # Calls variable function to transform extracted data
        self._call_var_function(cohort)

        if self.dynamic:
            self._apply_time_window(cohort)

            # Expects case_tmin, case_tmax for each primary key
            self._timefilter()
            self._apply_cleaning()
            self._add_relative_times(cohort)

        self._unify_and_order_columns(cohort.primary_key)
        ```

        Args:
            cohort (Cohort): The cohort to extract data for.

        Returns:
            pl.DataFrame: The extracted data.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
def _get_required_vars(self, cohort: Cohort) -> None: """Load required variables for this variable. Args: cohort (Cohort): The cohort instance to load variables from. Returns: None: Sets self.required_vars with loaded Variable objects. """ if len(self.requirements) == 0: logger.debug(f"No dependencies loaded for '{self.var_name}'.") return # Print dependency tree before extraction logger.info(f"Loading dependencies for '{self.var_name}':") utils.log_collection(logger=logger, collection=self.requirements) for save_as, config in self.requirements.items(): var_name = config["template"] if save_as in cohort.constant_vars: if (config["tmin"] and self.time_window.user_specified_tmin) or ( config["tmax"] and self.time_window.user_specified_tmax ): warnings.warn( f"tmin/tmax are ignored for the constant variable {var_name}. " "Please remove tmin/tmax.", UserWarning, ) self.required_vars[save_as] = ExtractedVariable( var_name=var_name, dynamic=False, data=cohort._obs.clone().select(cohort.primary_key, var_name), ) else: var = load_variable( var_name=var_name, cohort=cohort, time_window=TimeWindow( config["tmin"] or self.time_window.tmin, config["tmax"] or self.time_window.tmax, ), ) data = var.extract(cohort) self.required_vars[save_as] = ExtractedVariable( var_name=var.var_name, dynamic=var.dynamic, data=data ) def _call_var_function(self, cohort: Cohort) -> bool: """ Call the variable function if it exists. Args: cohort (Cohort) Returns: True if the function was called, False otherwise. Operations will be applied to Variable.data directly. 
""" if self.py is None: logger.debug(f"Variable function not called for {self.var_name}") return False code = self.py logger.debug(f"Calling Variable function {self.var_name}") if self.py_ready_polars: data = code(var=self, cohort=copy.deepcopy(cohort)) # Give type checker a hint if not isinstance(data, pl.DataFrame): raise TypeError("Variable function did not return a polars DataFrame") self.data = data else: # Legacy pandas code path logger.warning( f"Converting {self.var_name} to pandas. This is slow, consider rewriting your variable function to accept polars dataframes and then set py_ready_polars=True" ) # Ignore all type errors to keep the rest of the code clean. # Only operate on copies of Cohort and Variable to avoid # pandas Dataframes if an error occurs inside this code block. # Convert variable data to pandas (None for complex variables) pd_var = copy.deepcopy(self) if pd_var.data is not None: pd_var.data = utils.convert_to_pandas_df(pd_var.data) # type: ignore # Convert obs data to pandas pd_cohort = copy.deepcopy(cohort) pd_cohort._obs = utils.convert_to_pandas_df(pd_cohort._obs) # type: ignore # Convert required variables to pandas for reqvar in pd_var.required_vars.values(): if reqvar.data is not None: reqvar.data = utils.convert_to_pandas_df(reqvar.data) # type: ignore data = code(var=pd_var, cohort=pd_cohort) # Give type checker a hint if not isinstance(data, pd.DataFrame): raise TypeError("Variable function did not return a pandas DataFrame") self.data = utils.convert_to_polars_df(data) return True def _add_time_window(self, cohort: Cohort, join_key: str | None = None) -> None: """Adds tmin / tmax / join key / primary key from cohort.obs to self.data. This might duplicate data if the join_key (default: cohort.primary_key) is not unique in obs. Args: cohort (Cohort): The cohort instance containing time constraint information. 
join_key (str | None): Whether to apply filtering even if case_tmin/case_tmax columns don't exist (default: cohort.primary_key if set to None). Returns: None: Modifies self.data to add tmin / tmax / join key / primary key columns. """ if self.data is None: warnings.warn( "_add_time_window was called before extraction, skipping.", UserWarning ) return obs = cohort._obs.clone() # Remove NA tmin/tmax and tmin/tmax in year 9999, etc. obs = utils.remove_invalid_time_window(obs, time_window=self.time_window) # Return df with "case_tmin" and "case_tmin" columns obs = utils.add_time_window_expr( obs, time_window=self.time_window, aliases=("tmin", "tmax") ) # Return df with "case_tmin" and "case_tmax" columns join_key = join_key or cohort.primary_key cb = utils.get_cb( obs, join_key=join_key, primary_key=cohort.primary_key, tmin_col="tmin", tmax_col="tmax", ) self.data = self.data.join(cb, on=join_key) def _timefilter( self, time_window: TimeWindow | None = None, relation: WindowRelation = WindowRelation.ANY_OVERLAP, ) -> None: """Apply time-based filtering to variable data based on tmin/tmax constraints.""" if self.data is None: warnings.warn( "_timefilter was called before extraction, skipping.", UserWarning ) return window = time_window or TimeWindow("case_tmin", "case_tmax") tmin = window.tmin tmax = window.tmax # If these columns do not exist, we can assume that a filter is not intended for this variable has_time_columns = compatible_with_time_window(self.data, window) if not has_time_columns: logger.warning( f"Time-filter not applied to {self.var_name}, no {tmin} or {tmax} in data and not enforced" ) return start_col = "recordtime" if "recordtime_end" in self.data.columns: # Prepare data end_col = "recordtime_end" end_cleaned_col = "recordtime_end_clean" self.data = ( self.data # Ensure valid interval (end >= start) # Will also fill missing in end_col with values from start_col .with_columns( pl.max_horizontal(start_col, end_col).alias(end_cleaned_col), ) ) # Interval 
filter on recordtime and recordtime_end record_window = TimeWindow(start_col, end_cleaned_col) self.data = self.data.filter( WindowOverlap.overlap_expr( record=record_window, window=window, relation=relation, ) ) # Remove end_cleaned column self.data = self.data.drop(end_cleaned_col) else: # Simple filter on recordtime record_anchor = TimeAnchor(start_col) self.data = self.data.filter( WindowOverlap.contains_expr(record_anchor, window) ) # Remove tmin / tmax columns if tmin.column: self.data = self.data.drop(tmin.column) if tmax.column: self.data = self.data.drop(tmax.column) logger.info(f"Time-filter selected {len(self.data)} rows, column names:") logger.info(utils.pretty_join(self.data.columns, indent="\t")) def _apply_cleaning(self) -> bool: """Apply cleaning rules to filter out invalid values. Returns: True if the data was cleaned, False otherwise. Operations will be applied to Variable.data directly. """ if self.data is None: warnings.warn( "_apply_cleaning was called before extraction, skipping.", UserWarning ) return False if not self.cleaning: return False for col, bounds in self.cleaning.items(): if ( col == "value" and "value" not in self.data.columns and self.var_name in self.data.columns ): col = self.var_name if "low" in bounds: self.data = self.data.filter(pl.col(col).ge(bounds["low"])) if "high" in bounds: self.data = self.data.filter(pl.col(col).le(bounds["high"])) return True def _add_relative_times(self, cohort: Cohort) -> None: """Add relative time columns to dynamic variable data. Args: cohort (Cohort): The cohort instance containing reference time information. Returns: None: Modifies self.data to include recordtime_relative and recordtime_end_relative columns. 
""" if self.data is None: warnings.warn( "_add_relative_times was called before extraction, skipping.", UserWarning, ) return # Join with cohort data to get the reference time self.data = utils.convert_to_polars_df(self.data) self.data = self.data.join( cohort._obs.select(cohort.primary_key, cohort.t_min), on=cohort.primary_key, how="left", ) def relative_column( df: pl.DataFrame, time_col: str, reference_col: str ) -> pl.DataFrame: return df.with_columns( pl.col(time_col) .sub(pl.col(reference_col)) .dt.total_seconds() .alias(f"{time_col}_relative") ) if "recordtime" in self.data.columns: # Calculate relative times from icu_admission in seconds self.data = self.data.pipe( relative_column, time_col="recordtime", reference_col=cohort.t_min ) if ( "recordtime_end" in self.data.columns and not self.data["recordtime_end"].is_null().all() ): self.data = self.data.pipe( relative_column, time_col="recordtime_end", reference_col=cohort.t_min ) # Drop the reference time column self.data = self.data.drop(cohort.t_min) def _unify_and_order_columns(self, primary_key: str) -> None: """Order columns according to predefined COL_ORDER and add missing columns with null values. Args: primary_key (str): The primary key column name. Returns: None: Modifies self.data to have standardized column order. """ if self.data is None: warnings.warn( "_unify_and_order_columns was called before extraction, skipping.", UserWarning, ) return assert ( primary_key in PRIMARY_KEYS ), f"Primary key {primary_key} must be one of {PRIMARY_KEYS}." assert ( primary_key in self.data.columns ), f"Primary key {primary_key} not in data columns." 
# Add missing dynamic columns (no keys) with null values missing_cols = (col for col in DYN_COLUMNS if col not in self.data.columns) if missing_cols: self.data = self.data.with_columns( pl.lit(None).alias(col) for col in missing_cols ) # Order columns according to COL_ORDER, then add remaining columns data_cols = set(self.data.columns) ordered_cols = [col for col in COL_ORDER if col in data_cols] remaining_cols = [col for col in data_cols if col not in COL_ORDER] if len(remaining_cols) > 0: warnings.warn( "🚨 DEPRECATION ALERT 🚨\n" "Unknown columns found.\n" "- Please update the variable definition to move these to attributes.\n" "- Unknown columns are no longer supported and might be deprecated soon!\n" "- Columns: " + ", ".join(remaining_cols), UserWarning, ) self.data = self.data.select(ordered_cols + remaining_cols) # Uncomment after user warning # self.data = self.data.select(ordered_cols) # System methods def __repr__(self) -> str: return self.__str__() def __str__(self) -> str: return textwrap.dedent( f"""\ Variable: {self.var_name} ({self.__class__.__name__}, {'dynamic' if self.dynamic else 'static'}) ├── time_window: {self.time_window} └── data: {'Not extracted' if self.data is None else self.data.shape} """ ).strip() def __getstate__(self) -> dict[str, Any]: return self.__dict__ def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state)
class NativeVariable(Variable):
    """Extended Variable class for native variables from data sources.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series) (default: True).
        requires (RequirementsIterable): List of variables or dict of variables
            with tmin/tmax required to calculate the variable (default: ()).
        tmin (TimeBoundColumn | None): The tmin argument. Can either be a string
            (column name) or a tuple of (column name, timedelta) (default: None).
        tmax (TimeBoundColumn | None): The tmax argument. Can either be a string
            (column name) or a tuple of (column name, timedelta) (default: None).
        py (Callable): The function to call to calculate the variable
            (default: None).
        py_ready_polars (bool): True if the variable code can accept polars
            dataframes as input and return a polars dataframe (default: False).
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable
            (default: None).
        allow_caching (bool): Whether to allow caching of this variable
            (default: True).
    """

    # Whether extracted data for this variable may be cached
    allow_caching: bool

    def __init__(
        self,
        # Parent class parameters
        var_name: str,
        # FIX: default added to match the documented contract ("default: True");
        # backward compatible since all following parameters already have defaults.
        dynamic: bool = True,
        # Immutable default instead of the mutable-default-argument antipattern.
        requires: RequirementsIterable = (),
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        # This class parameters
        allow_caching: bool = True,
    ) -> None:
        super().__init__(
            var_name,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        self.allow_caching = allow_caching

    def build_attributes(self, data: PolarsFrame) -> PolarsFrame:
        """Combines columns prefixed by attributes_ into a struct column called attributes.

        Args:
            data (PolarsFrame): Frame that may contain ``attributes_*`` columns.

        Returns:
            PolarsFrame: The frame with all ``attributes_*`` columns folded into
            a single ``attributes`` struct column; returned unchanged for
            non-dynamic variables or when no such columns exist.
        """
        if not self.dynamic:
            # FIX: corrected "attribtues" typo in the log message.
            logger.debug("Skipping attributes for non-dynamic variables")
            return data

        attr_cols = [
            col
            for col in data.collect_schema().names()
            if col.startswith("attributes_")
        ]
        if attr_cols:
            # Strip the "attributes_" prefix to get the struct field names
            struct_expr = {col[len("attributes_") :]: pl.col(col) for col in attr_cols}
            # Mypy does not recognise the kwargs correctly as an overload
            data = data.with_columns(pl.struct(**struct_expr).alias("attributes"))  # type: ignore[call-overload]
            data = data.drop(attr_cols)
        return data
# NOTE(review): stale dead code below — it references self.tmin / self.tmax,
# which no longer exist on Variable (the window now lives in self.time_window).
# Update to use self.time_window before re-enabling, or delete.
# def on_admission(self, select: str = "!first value") -> NativeStatic:
#     """Create a new NativeStatic variable based on the current variable that extracts the value on admission.
#
#     Args:
#         select: Select clause specifying aggregation function and columns. Defaults to ``!first value``.
#
#     Returns:
#         NativeStatic
#
#     Examples:
#         >>> # Return the first value
#         >>> var_adm = variable.on_admission()
#         >>> cohort.add_variable(var_adm)
#
#         >>> # Be more specific with your selection
#         >>> var_adm = variable.on_admission("!closest(hospital_admission,0,2h) value")
#         >>> cohort.add_variable(var_adm)
#     """
#     if not self.dynamic:
#         raise ValueError("on_admission is only supported for dynamic variables")
#     return NativeStatic(
#         var_name=f"{self.var_name}_adm",
#         select=select,
#         base_var=self.var_name,
#         tmin=self.tmin,
#         tmax=self.tmax,
#     )