Source code for corr_vars.core.variable

from __future__ import annotations

import copy
import textwrap
import warnings

import pandas as pd
import polars as pl

from corr_vars import logger
import corr_vars.utils as utils
from corr_vars.sources.var_loader import load_variable

from typing import Any, TYPE_CHECKING
from corr_vars.definitions.constants import DYN_COLUMNS, COL_ORDER, PRIMARY_KEYS
from corr_vars.definitions.typing import ExtractedVariable, PolarsFrame

if TYPE_CHECKING:
    from corr_vars.core import Cohort
    from corr_vars.definitions.typing import (
        TimeBoundColumn,
        RequirementsIterable,
        RequirementsDict,
        CleaningDict,
        VariableCallable,
    )


class Variable:
    """Base class for all variables.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable): List of variables or dict of variables with
            tmin/tmax required to calculate the variable (default: []).
        tmin (TimeBoundColumn | None): The tmin argument. Can either be a string
            (column name) or a tuple of (column name, timedelta).
        tmax (TimeBoundColumn | None): The tmax argument. Can either be a string
            (column name) or a tuple of (column name, timedelta).
        py (Callable): The function to call to calculate the variable.
        py_ready_polars (bool): True if the variable code can accept polars
            dataframes as input and return a polars dataframe.
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable.

    Note:
        tmin and tmax can be None when you create a Variable object, but they must be
        set before extraction. If you add the variable via cohort.add_variable(), they
        are automatically set to the cohort's tmin and tmax.

        This base class should not be used directly; use one of the subclasses
        instead. These are specified in the sources submodule.

    Examples:
        Basic cleaning configuration:

        >>> cleaning = {
        ...     "value": {
        ...         "low": 10,
        ...         "high": 80
        ...     }
        ... }

        Time constraints with relative offsets:

        >>> # Extract data from 2 hours before to 6 hours after ICU admission
        >>> tmin = ("icu_admission", "-2h")
        >>> tmax = ("icu_admission", "+6h")

        Variable with dependencies:

        >>> # This variable requires other variables to be loaded first
        >>> requires = ["blood_pressure_sys", "blood_pressure_dia"]
        >>> # This variable requires other variables with fixed tmin/tmax to be loaded first
        >>> requires = {
        ...     "blood_pressure_sys_hospital": {
        ...         "template": "blood_pressure_sys",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     },
        ...     "blood_pressure_dia_hospital": {
        ...         "template": "blood_pressure_dia",
        ...         "tmin": "hospital_admission",
        ...         "tmax": "hospital_discharge"
        ...     }
        ... }
    """

    var_name: str
    dynamic: bool
    requires: RequirementsIterable
    required_vars: dict[str, ExtractedVariable]
    data: pl.DataFrame | None
    tmin: TimeBoundColumn | None
    tmax: TimeBoundColumn | None
    py: VariableCallable | None
    py_ready_polars: bool
    cleaning: CleaningDict | None

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable = [],
        tmin: TimeBoundColumn | None = None,
        tmax: TimeBoundColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> None:
        self.var_name = var_name
        self.dynamic = dynamic  # True if the variable is dynamic (time-series)
        self.requires = requires
        self.required_vars = {}
        self.data = None
        self.tmin = tmin
        self.tmax = tmax
        self.py = py
        self.py_ready_polars = py_ready_polars
        self.cleaning = cleaning
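    # A minimal sketch of how the documented arguments fit together. The subclass
    # name SomeVariableSubclass and the variable name "heart_rate" are hypothetical,
    # used for illustration only:
    #
    #     var = SomeVariableSubclass(
    #         var_name="heart_rate",
    #         dynamic=True,
    #         tmin=("icu_admission", "-2h"),
    #         tmax=("icu_admission", "+6h"),
    #         cleaning={"value": {"low": 20, "high": 250}},
    #     )
    #     cohort.add_variable(var)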
    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """Extracts data from the datasource.

        Usually follows this pattern for dynamic (time-series) data:

        ```
        self._get_required_vars(cohort)

        # This should change self.data, either by returning or by side effect
        self.data = self._custom_extraction(cohort)

        # Or convert_polars=False for pandas functions
        self._call_var_function(cohort, convert_polars=True)

        # Expects case_tmin, case_tmax for each primary key
        self._timefilter(cohort, always=not self.complex)

        self._apply_cleaning()
        self._add_relative_times(cohort)
        self._unify_and_order_columns(cohort.primary_key)
        ```
        """
        raise NotImplementedError
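    # Illustrative sketch (not part of the library): a subclass implementing
    # extract() with the pattern documented above. MyVariable and
    # _custom_extraction are hypothetical names.
    #
    #     class MyVariable(Variable):
    #         def extract(self, cohort: Cohort) -> pl.DataFrame:
    #             self._get_required_vars(cohort)
    #             self.data = self._custom_extraction(cohort)
    #             self._call_var_function(cohort, convert_polars=True)
    #             self._timefilter(cohort, always=True)
    #             self._apply_cleaning()
    #             self._add_relative_times(cohort)
    #             self._unify_and_order_columns(cohort.primary_key)
    #             return self.data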
    def _get_required_vars(self, cohort: Cohort) -> None:
        """Load required variables for this variable.

        Args:
            cohort (Cohort): The cohort instance to load variables from.

        Returns:
            None: Sets self.required_vars with loaded Variable objects.
        """
        if len(self.requires) == 0:
            return

        # Print the dependency tree before extraction
        logger.info(f"Loading dependencies for {self.var_name}:")
        utils.log_collection(logger=logger, collection=self.requires)

        for var_name in self.requires:
            overrides: RequirementsDict = {}
            if isinstance(self.requires, dict):
                overrides = copy.deepcopy(self.requires[var_name])

            if var_name in cohort.constant_vars:
                if overrides.get("tmin") or overrides.get("tmax"):
                    warnings.warn(
                        f"tmin/tmax are ignored for the constant variable {var_name}. "
                        "Please remove tmin/tmax.",
                        UserWarning,
                    )
                self.required_vars[var_name] = ExtractedVariable(
                    var_name=var_name,
                    dynamic=False,
                    data=cohort._obs.clone().select(cohort.primary_key, var_name),
                )
            else:
                overrides.setdefault("template", var_name)
                overrides.setdefault("tmin", self.tmin)
                overrides.setdefault("tmax", self.tmax)

                var = load_variable(
                    var_name=overrides["template"],
                    cohort=cohort,
                    tmin=overrides["tmin"],
                    tmax=overrides["tmax"],
                )
                data = var.extract(cohort)
                self.required_vars[var_name] = ExtractedVariable(
                    var_name=var.var_name, dynamic=var.dynamic, data=data
                )

    def _call_var_function(self, cohort: Cohort, convert_polars: bool = False) -> bool:
        """Call the variable function if it exists.

        Args:
            cohort (Cohort): The cohort instance passed to the variable function.
            convert_polars (bool): Whether the result should be a polars DataFrame
                (default: False).

        Returns:
            bool: True if the function was called, False otherwise. Operations are
            applied to Variable.data directly.
        """
        if self.py is None:
            return False

        code = self.py
        logger.debug(f"Calling Variable function {self.var_name}")

        if self.py_ready_polars:
            self.data = code(var=self, cohort=copy.deepcopy(cohort))
            # Give the type checker a hint
            assert isinstance(
                self.data, pl.DataFrame
            ), "Variable function did not return a polars DataFrame"
            if not convert_polars:
                self.data = utils.convert_to_pandas_df(self.data)
        else:
            # Legacy pandas code path
            logger.warning(
                f"Converting {self.var_name} to pandas. This is slow; consider "
                "rewriting your variable function to accept polars dataframes and "
                "then set py_ready_polars=True"
            )
            if self.data is not None:
                self.data = utils.convert_to_pandas_df(self.data)

            pd_cohort = copy.deepcopy(cohort)
            pd_cohort._obs = pd_cohort._obs.to_pandas()

            # Convert required variables to pandas.
            # These do not need to be converted back as they are only used here.
            for reqvar in self.required_vars.values():
                if reqvar.data is not None:
                    reqvar.data = reqvar.data.to_pandas()

            self.data = code(var=self, cohort=pd_cohort)
            # Give the type checker a hint
            assert isinstance(
                self.data, pd.DataFrame
            ), "Variable function did not return a pandas DataFrame"
            if convert_polars:
                self.data = utils.convert_to_polars_df(self.data)

        return True

    def _timefilter(
        self, cohort: Cohort, always: bool = False, join_key: str = "case_id"
    ) -> None:
        """Apply time-based filtering to variable data based on tmin/tmax constraints.

        Args:
            cohort (Cohort): The cohort instance containing time constraint information.
            always (bool): Whether to apply filtering even if case_tmin/case_tmax
                columns don't exist (default: False).
            join_key (str): Column used to join the computed case_tmin/case_tmax
                bounds onto the data (default: "case_id").

        Returns:
            None: Modifies self.data to filter rows within time constraints.
        """
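        # Worked example of the interval-overlap filter below (times are
        # hypothetical), with case_tmin=08:00 and case_tmax=12:00:
        #   recordtime=07:00, recordtime_end=09:00 -> overlap 08:00-09:00, kept
        #   recordtime=13:00, recordtime_end=14:00 -> overlap_start > overlap_end, dropped
        #   recordtime=11:00, recordtime_end=null  -> treated as the instant 11:00, kept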
""" if self.data is None: warnings.warn( "Timefilter was called before extraction, skipping.", UserWarning ) return if always: df = cohort._obs.clone() df = utils.pl_parse_time_args(df, tmin=self.tmin, tmax=self.tmax) cb = utils.get_cb( df, join_key=join_key, primary_key=cohort.primary_key ) # Return df with case_tmin and case_tmax columns self.data = self.data.join(cb, on=join_key) # If these columns do not exist, we can assume that a filter is not intended for this variable has_columns = ( "case_tmin" in self.data.columns and "case_tmax" in self.data.columns ) if has_columns: if "recordtime_end" in self.data.columns: # Filter for any rows overlapping with case_tmin and case_tmax self.data = ( self.data # Fill missings .with_columns( pl.when(pl.col("recordtime_end").is_null()) .then(pl.col("recordtime")) .otherwise(pl.col("recordtime_end")) .alias("recordtime_end_clean"), ) # Ensure valid interval (end >= start), swap if necessary .with_columns( pl.when(pl.col("recordtime_end_clean") < pl.col("recordtime")) .then(pl.col("recordtime")) .otherwise(pl.col("recordtime_end_clean")) .alias("recordtime_end_clean"), ) # Calculate overlap .with_columns( pl.max_horizontal(["recordtime", "case_tmin"]).alias( "overlap_start" ), pl.min_horizontal(["recordtime_end_clean", "case_tmax"]).alias( "overlap_end" ), ) # Filter to only intervals that have some overlap .filter(pl.col("overlap_start") <= pl.col("overlap_end")) # Drop intermediate columns .drop( "overlap_start", "overlap_end", "recordtime_end_clean", "case_tmin", "case_tmax", ) ) else: # Simple filter on recordtime self.data = self.data.filter( pl.col("recordtime").is_between("case_tmin", "case_tmax") ).drop("case_tmin", "case_tmax") logger.info(f"Time-filter selected {len(self.data)} rows, column names:") logger.info(utils.pretty_join(self.data.columns, indent="\t")) else: logger.warning( f"Time-filter not applied to {self.var_name}, no case_tmin or case_tmax in data and not enforced" ) def _apply_cleaning(self) -> None: """Apply cleaning rules to filter out invalid values. Returns: pl.DataFrame: The cleaned data. """ if self.data is None: warnings.warn( "Cleaning was called before extraction, skipping.", UserWarning ) return if self.cleaning: for col, bounds in self.cleaning.items(): if ( col == "value" and "value" not in self.data.columns and self.var_name in self.data.columns ): col = self.var_name if "low" in bounds: self.data = self.data.filter(pl.col(col).ge(bounds["low"])) if "high" in bounds: self.data = self.data.filter(pl.col(col).le(bounds["high"])) def _add_relative_times(self, cohort: Cohort) -> None: """Add relative time columns to dynamic variable data. Args: cohort (Cohort): The cohort instance containing reference time information. Returns: None: Modifies self.data to include recordtime_relative and recordtime_end_relative columns. 
""" if self.data is None: warnings.warn( "Relative time was called before extraction, skipping.", UserWarning ) return # Join with cohort data to get the reference time self.data = utils.convert_to_polars_df(self.data) self.data = self.data.join( cohort._obs.select(cohort.primary_key, cohort.t_min), on=cohort.primary_key, how="left", ) def relative_column( df: pl.DataFrame, time_col: str, reference_col: str ) -> pl.DataFrame: return df.with_columns( pl.col(time_col) .sub(pl.col(reference_col)) .dt.total_seconds() .alias(f"{time_col}_relative") ) if "recordtime" in self.data.columns: # Calculate relative times from icu_admission in seconds self.data = self.data.pipe( relative_column, time_col="recordtime", reference_col=cohort.t_min ) if ( "recordtime_end" in self.data.columns and not self.data.select(pl.col("recordtime_end").is_null().all()).item() ): self.data = self.data.pipe( relative_column, time_col="recordtime_end", reference_col=cohort.t_min ) # Drop the reference time column self.data = self.data.drop(cohort.t_min) def _unify_and_order_columns(self, primary_key: str) -> None: """Order columns according to predefined COL_ORDER and add missing columns with null values. Args: primary_key (str): The primary key column name. Returns: None: Modifies self.data to have standardized column order. """ if self.data is None: warnings.warn( "Unify and order was called before extraction, skipping.", UserWarning ) return assert ( primary_key in PRIMARY_KEYS ), f"Primary key {primary_key} must be one of {PRIMARY_KEYS}." assert ( primary_key in self.data.columns ), f"Primary key {primary_key} not in data columns." # Add missing dynamic columns (no keys) with null values missing_cols = (col for col in DYN_COLUMNS if col not in self.data.columns) if missing_cols: self.data = self.data.with_columns( pl.lit(None).alias(col) for col in missing_cols ) # Order columns according to COL_ORDER, then add remaining columns data_cols = set(self.data.columns) ordered_cols = [col for col in COL_ORDER if col in data_cols] remaining_cols = [col for col in data_cols if col not in COL_ORDER] if len(remaining_cols) > 0: logger.warning( "🚨 DEPRECATION ALERT 🚨 \nUnknown columns found.\n - Please update the variable definition to move these to attributes. \n - Unknown columns are no longer supported and might be deprecated soon! \n - Columns:" + ", ".join(remaining_cols) ) self.data = self.data.select(ordered_cols + remaining_cols) # Uncomment after user warning # self.data = self.data.select(ordered_cols) # System methods def __repr__(self) -> str: return self.__str__() def __str__(self) -> str: return textwrap.dedent( f"""\ Variable: {self.var_name} ({self.__class__.__name__}, {'dynamic' if self.dynamic else 'static'}) ├── tmin: {self.tmin}, tmax: {self.tmax} └── data: {'Not extracted' if self.data is None else self.data.shape} """ ).strip() def __getstate__(self) -> dict[str, Any]: return self.__dict__ def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state)
class NativeVariable(Variable):
    """Extended Variable class for native variables from data sources.

    Args:
        var_name (str): The variable name.
        dynamic (bool): True if the variable is dynamic (time-series).
        requires (RequirementsIterable): List of variables or dict of variables with
            tmin/tmax required to calculate the variable (default: []).
        tmin (TimeBoundColumn | None): The tmin argument. Can either be a string
            (column name) or a tuple of (column name, timedelta) (default: None).
        tmax (TimeBoundColumn | None): The tmax argument. Can either be a string
            (column name) or a tuple of (column name, timedelta) (default: None).
        py (Callable): The function to call to calculate the variable (default: None).
        py_ready_polars (bool): True if the variable code can accept polars
            dataframes as input and return a polars dataframe (default: False).
        cleaning (CleaningDict): Dictionary with cleaning rules for the variable
            (default: None).
        allow_caching (bool): Whether to allow caching of this variable (default: True).
    """

    def __init__(
        self,
        # Parent class parameters
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable = [],
        tmin: TimeBoundColumn | None = None,
        tmax: TimeBoundColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        # Parameters of this class
        allow_caching: bool = True,
    ) -> None:
        super().__init__(
            var_name,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        self.allow_caching = allow_caching
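    # Illustrative construction (the variable name and cleaning bounds are
    # hypothetical):
    #
    #     hr = NativeVariable(
    #         var_name="heart_rate",
    #         dynamic=True,
    #         cleaning={"value": {"low": 20, "high": 250}},
    #         allow_caching=False,
    #     )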
    def build_attributes(self, data: PolarsFrame) -> PolarsFrame:
        """Combine columns prefixed with attributes_ into a struct column called attributes."""
        if not self.dynamic:
            logger.debug("Skipping attributes for non-dynamic variables")
            return data

        attr_cols = [
            col
            for col in data.collect_schema().names()
            if col.startswith("attributes_")
        ]
        if attr_cols:
            struct_expr = {
                col[len("attributes_"):]: pl.col(col) for col in attr_cols
            }
            # Mypy does not recognise the kwargs correctly as an overload
            data = data.with_columns(pl.struct(**struct_expr).alias("attributes"))  # type: ignore[call-overload]
            data = data.drop(attr_cols)
        return data
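    # Sketch of build_attributes in action (column names are hypothetical):
    #   before: case_id | recordtime | value | attributes_unit | attributes_device
    #   after:  case_id | recordtime | value | attributes (struct with fields unit, device)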
    # def on_admission(self, select: str = "!first value") -> NativeStatic:
    #     """Create a new NativeStatic variable based on the current variable that
    #     extracts the value on admission.
    #
    #     Args:
    #         select: Select clause specifying the aggregation function and columns.
    #             Defaults to ``!first value``.
    #
    #     Returns:
    #         NativeStatic
    #
    #     Examples:
    #         >>> # Return the first value
    #         >>> var_adm = variable.on_admission()
    #         >>> cohort.add_variable(var_adm)
    #
    #         >>> # Be more specific with your selection
    #         >>> var_adm = variable.on_admission("!closest(hospital_admission,0,2h) value")
    #         >>> cohort.add_variable(var_adm)
    #     """
    #     if not self.dynamic:
    #         raise ValueError("on_admission is only supported for dynamic variables")
    #     return NativeStatic(
    #         var_name=f"{self.var_name}_adm",
    #         select=select,
    #         base_var=self.var_name,
    #         tmin=self.tmin,
    #         tmax=self.tmax,
    #     )