Source code for corr_vars.sources.cub_hdp.extract

from __future__ import annotations
import os
import pickle

import pandas as pd
import polars as pl

from corr_vars import logger
from corr_vars.core.variable import NativeVariable
from corr_vars.core.variable import Variable as BaseVariable
from corr_vars.sources.var_loader import load_variable, MultiSourceVariable
from corr_vars.sources.cub_hdp import helpers
import corr_vars.utils as utils

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from corr_vars.core import Cohort
    from corr_vars.definitions import (
        TimeBoundColumn,
        RequirementsIterable,
        CleaningDict,
        VariableCallable,
    )


def Variable(*args, **kwargs) -> BaseVariable:
    """Factory function to create the correct variable type based on the arguments."""
    logger.info("Creating variable with:")
    if args:
        prefix = "└── " if not kwargs else "├── "
        logger.info(f"{prefix}args: {args}")
    if kwargs:
        prefix = "└── "
        logger.info(f"{prefix}kwargs:")
        utils.log_dict(logger=logger, dictionary=kwargs, indent=4)

    # Legacy naming: 'variable' was renamed to 'base_var'
    if kwargs.get("variable"):
        kwargs["base_var"] = kwargs.pop("variable")

    var_type = kwargs.pop("type", None)
    if var_type is None:
        logger.error("CubHDP Variable requires the 'type' field.")

    # Overwrite dynamic attribute
    if var_type in ["native_dynamic", "derived_dynamic"]:
        kwargs["dynamic"] = True
    # Add complex attribute
    if var_type == "complex":
        kwargs["complex"] = True

    if var_type in ["native_dynamic", "complex"]:
        return NativeDynamic(*args, **kwargs)
    elif var_type == "derived_dynamic":
        return DerivedDynamic(*args, **kwargs)
    elif var_type == "native_static":
        return NativeStatic(*args, **kwargs)
    elif var_type == "derived_static":
        return DerivedStatic(*args, **kwargs)
    else:
        raise KeyError(f"Type {var_type} does not exist.")
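
# Illustrative sketch (not part of the module): how the factory dispatches on the
# 'type' field. The variable name, table, and filter below are hypothetical and only
# demonstrate the call shape; real definitions come from the CORR variable
# configuration.
#
#     hr = Variable(
#         type="native_dynamic",
#         var_name="heart_rate",          # hypothetical variable name
#         table="vital_signs",            # hypothetical table name
#         where="parameter = 'HR'",       # hypothetical filter
#         value_dtype="float",
#     )
#     isinstance(hr, NativeDynamic)       # True: 'native_dynamic' maps to NativeDynamic
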
class NativeDynamic(NativeVariable):
    """Variable class for the cub_hdp source.

    Args:
        var_name: Name of the variable
        dynamic: Whether the variable is dynamic
        requires: List of required variables
        tmin: Minimum time point
        tmax: Maximum time point
        py: Python function to apply to the variable
        py_ready_polars: Whether the variable is already polars ready
        cleaning: Dictionary of cleaning functions
        allow_caching: Whether to allow caching
        table: Name of the table to extract from
        where: WHERE clause to filter the table
        value_dtype: Data type of the value column
        columns: List of columns to extract
        complex: Whether the variable is complex (skips the database extraction;
            the query must be provided in py)

    Note:
        Only use this class for defining custom variables. For CORR-specified
        variables, use cohort.add_variable() directly.
    """

    def __init__(
        self,
        # core.variable.Variable
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable = [],
        tmin: TimeBoundColumn | None = None,
        tmax: TimeBoundColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        # core.variable.NativeVariable(Variable)
        allow_caching: bool = True,
        # sources.cub_hdp.extract.Variable(NativeVariable)
        table: str | None = None,
        where: str | None = None,
        value_dtype: str | None = None,
        columns: dict[str, str | dict[str, str]] | None = None,
        complex: bool = False,
    ) -> None:
        super().__init__(
            var_name=var_name,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
            allow_caching=allow_caching,
        )
        self.table = table
        self.where = where
        self.value_dtype = value_dtype
        self.columns = columns
        self.complex = complex
        # Always apply the timefilter for non-complex variables
        self._timefilter_always = not self.complex

    def extract(self, cohort: Cohort) -> pl.DataFrame:
        assert self.tmin is not None, "tmin must be provided for all variables"
        assert self.tmax is not None, "tmax must be provided for all variables"

        self.data = self._get_from_cache(cohort)
        if self.data is None:
            self._get_required_vars(cohort)
            # NativeDynamic variables
            if not self.complex:
                self.data = self._extract_from_db(cohort)
            function_called = self._call_var_function(cohort, convert_polars=True)
            if self.data is None:
                if self.complex and not function_called:
                    raise ValueError(
                        "Complex variables need to extract data inside a custom variable function."
                    )
                else:
                    raise ValueError("Variable extraction resulted in no data.")
            self.data = self.build_attributes(self.data)
            self._write_to_cache(cohort)

        if self.dynamic:
            # Complex variables are time-filtered only if case_tmin and case_tmax
            # are prepared in the variable function
            self._timefilter(cohort, always=self._timefilter_always)

        self._apply_cleaning()
        self._add_relative_times(cohort)
        self._unify_and_order_columns(cohort.primary_key)
        return self.data

    def _extract_from_db(self, cohort: Cohort) -> pl.DataFrame:
        """Extract data from the database."""
        if self.complex:
            raise RuntimeError(
                "Complex NativeDynamic should not call _extract_from_db."
            )
        with helpers.temporary_table(cohort) as temp_tbl:
            query, table_info, keep_cols = helpers.build_base_query(
                self, helpers._get_cub_hdp_database(cohort), temp_tbl
            )
            conn = helpers.reinit_conn(cohort)
            result_df: pl.LazyFrame = conn.lf_read_sql(
                query=query,
                datetime_cols=[
                    col
                    for col in table_info
                    if "recordtime" in col or "storetime" in col
                ],
                keep_cols=keep_cols,
            )
            logger.info(f"Extracted {result_df.select(pl.len()).collect().item()} rows")
            result_df = self.build_attributes(result_df)
            return result_df.collect()

    def _get_from_cache(self, cohort: Cohort) -> pl.DataFrame | None:
        path = cohort.tmpdir_manager.create_tmpdir_variable_path(
            var_name=self.var_name, extension=None
        )
        if self.allow_caching and os.path.exists(f"{path}.parquet"):
            # Before loading, ensure that cacheinfo exists and all case_ids from
            # cohort.obs are present
            if not os.path.exists(f"{path}.cacheinfo"):
                return None
            with open(f"{path}.cacheinfo", "rb") as f:
                cache_case_ids = pickle.load(f)
            cohort_case_ids = set(
                cohort.obs.select("case_id").unique().to_series().to_list()
            )
            # Return None if cohort.obs contains uncached case_ids
            if not cohort_case_ids.issubset(cache_case_ids):
                return None

            self.data = pl.read_parquet(f"{path}.parquet")
            logger.info(
                f"Loaded {self.var_name} from tmpdir ({self.data.shape[0]} rows)"
            )
            if "case_id" in self.data.columns:
                self.data = self.data.filter(pl.col("case_id").is_in(cohort_case_ids))
                logger.info(f"Found {self.data.shape[0]} rows matching cohort.obs")
            if "case_tmin" in self.data.columns and "case_tmax" in self.data.columns:
                # For complex variables that have precomputed case_tmin and case_tmax,
                # the timefilter needs to be reapplied - we set _timefilter_always to
                # True. For this, we need to drop case_tmin, case_tmax and any
                # non-case_id primary key.
                drop_cols = ["case_tmin", "case_tmax"] + (
                    [cohort.primary_key] if cohort.primary_key != "case_id" else []
                )
                self.data = self.data.drop(drop_cols, strict=False)
                self._timefilter_always = True
            return self.data
        return None

    def _write_to_cache(self, cohort: Cohort) -> None:
        if self.allow_caching and self.data is not None:
            path = cohort.tmpdir_manager.create_tmpdir_variable_path(
                var_name=self.var_name, extension=None
            )
            self.data.write_parquet(f"{path}.parquet")
            case_ids = cohort.obs.select("case_id").unique().to_series().to_list()
            if os.path.exists(f"{path}.cacheinfo"):
                os.remove(f"{path}.cacheinfo")
            # Write cacheinfo
            with open(f"{path}.cacheinfo", "wb") as f:
                pickle.dump(case_ids, f)
            logger.info(f"Wrote {self.var_name} to tmpdir ({self.data.shape[0]} rows)")
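
# Illustrative sketch (hedged): defining a custom NativeDynamic variable. The table
# name, WHERE clause, and time-bound columns below are hypothetical placeholders;
# the exact schema depends on the cub_hdp database configured for the cohort.
#
#     lactate = NativeDynamic(
#         var_name="blood_lactate",
#         dynamic=True,
#         table="laboratory",             # hypothetical table name
#         where="analysis_name = 'LAC'",  # hypothetical filter
#         value_dtype="float",
#         tmin="hospital_admission",
#         tmax="hospital_discharge",
#     )
#     data = lactate.extract(cohort)      # returns a polars DataFrame; also kept on lactate.data
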
class NativeStatic(BaseVariable):
    """Aggregated variables represent simple aggregations of dynamic variables.

    Args:
        var_name (str): Name of the variable.
        select (str): Select clause specifying aggregation function and columns.
        base_var (str): Name of the base variable (must be a native_dynamic variable).
        where (str, optional): Optional WHERE clause filtering the base variable before aggregation.
        tmin (str, optional): Minimum time for the extraction.
        tmax (str, optional): Maximum time for the extraction.

    The select argument supports several aggregation functions:

    - ``!first [columns]``: Returns the first row within this case

      >>> "!first value"              # Single column
      >>> "!first value, recordtime"  # Multiple columns

    - ``!last [columns]``: Returns the last row within this case

      >>> "!last value"
      >>> "!last value, recordtime"

    - ``!any``: Returns True if any value exists

      >>> "!any"
      >>> "!any value"

    - ``!sum [column]``: Calculates sum of values

      >>> "!sum value"

    - ``!closest(to_column, timedelta, plusminus) [columns]``: Selects the value closest to the specified column

      Args:
          to_column: Column to compare "recordtime" against
          timedelta: Time to add to "to_column" for comparison
          plusminus: Allowed time mismatch (different before/after values can be separated by a space)

      >>> "!closest(hospital_admission) value, recordtime"   # Closest to admission
      >>> "!closest(hospital_admission, 0, 2h 3h) value"     # 2h before to 3h after
      >>> "!closest(first_intubation_dtime, 6h, 2h) value"   # 6h after intubation ±2h

    - ``!mean [column]``: Calculates mean value

      >>> "!mean value"

    - ``!median [column]``: Calculates median value

      >>> "!median value"

    - ``!perc(quantile) [column]``: Calculates specified percentile

      >>> "!perc(75) value"  # 75th percentile

    The implementation additionally supports ``!min``, ``!max``, ``!std``, ``!count``
    and ``!sum_interval`` (prorates interval values by their overlap with the
    observation window; requires a ``recordtime_end`` column in the base variable).

    The where argument supports pandas-style boolean expressions. These are converted
    to polars expressions and evaluated against the base variable data. Where also
    supports magic commands (starting with !) to filter the data. Supported commands are:

    - ``!isin(column, [values])``: Filters rows where the value in column is in values
    - ``!startswith(column, [values])``: Filters rows where the value in column starts with any of the values
    - ``!endswith(column, [values])``: Filters rows where the value in column ends with any of the values
    """

    base_var: BaseVariable | MultiSourceVariable

    def __init__(
        self,
        var_name: str,
        select: str,
        base_var: str | BaseVariable,
        where: str | None = None,
        tmin: TimeBoundColumn | None = None,
        tmax: TimeBoundColumn | None = None,
        cleaning: CleaningDict | None = None,
    ) -> None:
        super().__init__(
            var_name,
            dynamic=False,
            tmin=tmin,
            tmax=tmax,
            cleaning=cleaning,
        )
        self.select = select
        logger.debug(
            f"NativeStatic: SELECT {self.select} FROM {base_var} WHERE {where}"
        )
        self.where = where
        self._base_var = base_var
        self.agg_func, self.agg_params, self.select_cols = helpers.parse_select(
            self.select
        )
        if len(self.select_cols) == 0:
            self.select_cols = ["value"]

    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """Extract the variable.

        You do not need to call this yourself, as it is called internally when you
        add the variable to a cohort. However, you may call it directly to obtain
        variable data independently of the cohort. You still need a cohort object
        for case ids and other metadata.

        Args:
            cohort: Cohort object.

        Returns:
            Extracted variable. After extraction, you may also access the data as
            ``Variable.data``.

        Examples:
            >>> var = NativeStatic(
            ...     var_name="first_sodium_recordtime",
            ...     select="!first recordtime",
            ...     base_var="blood_sodium",
            ...     tmin="hospital_admission"
            ... )
            >>> var.extract(cohort)  # With var.extract(), the data will not be added to the cohort.
            >>> var.data  # You can access the data as a polars dataframe.
        """
        assert (
            self.tmin is not None
        ), "tmin must be provided for all variables upon extraction"
        assert (
            self.tmax is not None
        ), "tmax must be provided for all variables upon extraction"

        if isinstance(self._base_var, str):
            self.base_var = load_variable(
                self._base_var, cohort=cohort, tmin=self.tmin, tmax=self.tmax
            )
        else:
            self._base_var.tmin = self.tmin
            self._base_var.tmax = self.tmax
            self.base_var = self._base_var

        self.base_data = self.base_var.extract(cohort=cohort)
        self.data = self._extract_aggregation(cohort)
        self._apply_cleaning()
        self._call_var_function(cohort)
        return self.data

    def _extract_aggregation(self, cohort: Cohort) -> pl.DataFrame:
        """Extract static variable using the base variable.

        Args:
            cohort (Cohort): Cohort object.

        Returns:
            pl.DataFrame: Extracted variable.
        """
        base_data = self.base_data.lazy()
        obs = cohort._obs.lazy()

        # Sort by recordtime
        base_data = base_data.sort("recordtime")

        # Handle attributes if present
        if "attributes" in base_data.collect_schema().names():
            # Flatten JSON attributes into separate columns
            attrs_sample = base_data.select("attributes").limit(1).collect()
            if attrs_sample.height > 0 and attrs_sample["attributes"][0] is not None:
                # Get the first non-null attribute to determine structure
                sample_attrs = attrs_sample["attributes"][0]
                if isinstance(sample_attrs, dict) and sample_attrs:
                    # Create expressions to extract each attribute
                    attr_exprs = [
                        pl.col("attributes")
                        .struct.field(key)
                        .alias(f"attributes.{key}")
                        for key in sample_attrs.keys()
                    ]
                    base_data = base_data.with_columns(attr_exprs).drop("attributes")

        # Apply where clause filtering
        if self.where:
            if self.where.startswith("!"):
                where_func, where_params, _ = helpers.parse_select(self.where)
                match where_func:
                    case "isin":
                        base_data = base_data.filter(
                            pl.col(where_params[0]).is_in(where_params[1])
                        )
                    case "startswith":
                        # Create OR expression for multiple prefixes
                        if len(where_params[1]) == 1:
                            base_data = base_data.filter(
                                pl.col(where_params[0]).str.starts_with(
                                    where_params[1][0]
                                )
                            )
                        else:
                            starts_expr = pl.col(where_params[0]).str.starts_with(
                                where_params[1][0]
                            )
                            for prefix in where_params[1][1:]:
                                starts_expr = starts_expr | pl.col(
                                    where_params[0]
                                ).str.starts_with(prefix)
                            base_data = base_data.filter(starts_expr)
                    case "endswith":
                        # Create OR expression for multiple suffixes
                        if len(where_params[1]) == 1:
                            base_data = base_data.filter(
                                pl.col(where_params[0]).str.ends_with(
                                    where_params[1][0]
                                )
                            )
                        else:
                            ends_expr = pl.col(where_params[0]).str.ends_with(
                                where_params[1][0]
                            )
                            for suffix in where_params[1][1:]:
                                ends_expr = ends_expr | pl.col(
                                    where_params[0]
                                ).str.ends_with(suffix)
                            base_data = base_data.filter(ends_expr)
                    case _:
                        raise ValueError(
                            f"Unsupported where function: {where_func}. "
                            "Supported functions are: isin, startswith, endswith"
                        )
            else:
                # Convert pandas-style where clause to polars expression
                polars_where = helpers.sql_to_polars(self.where)
                base_data = base_data.filter(eval(polars_where))

        # Parse time arguments
        obs = utils.pl_parse_time_args(obs, tmin=self.tmin, tmax=self.tmax)

        # Join with observation times
        base_data = base_data.join(
            obs.select([cohort.primary_key, "tmin", "tmax"]),
            on=cohort.primary_key,
            how="left",
        )
        logger.debug(
            f"Base data shape (before filtering): {base_data.select(pl.len()).collect().item()}"
        )
        base_data_time = base_data.filter(
            pl.col("recordtime").is_between("tmin", "tmax")
        )
        logger.debug(
            f"Base data shape (after filtering tmin and tmax): {base_data_time.select(pl.len()).collect().item()}"
        )

        params = self.agg_params

        # Perform aggregation using Polars syntax
        match self.agg_func.lower():
            case "first":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).first()
                )
            case "last":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).last()
                )
            case "mean":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).mean()
                )
            case "median":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).median()
                )
            case "min":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).min()
                )
            case "max":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).max()
                )
            case "std":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).std()
                )
            case "perc":
                quantile = float(params[0]) / 100
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).quantile(quantile)
                )
            case "sum":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(self.select_cols).sum()
                )
            case "count":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.len().alias(self.select_cols[0])
                )
            case "any":
                res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(col).is_not_null().any().alias(col)
                    for col in self.select_cols
                )
                # Join with all observations to include cases with no data
                res = (
                    obs.select(cohort.primary_key)
                    .unique()
                    .join(res, on=cohort.primary_key, how="left")
                    .with_columns(
                        [pl.col(col).fill_null(False) for col in self.select_cols]
                    )
                )
            case "sum_interval":
                # Use unfiltered base_data instead of base_data_time here
                if "recordtime_end" not in base_data.collect_schema().names():
                    raise ValueError(
                        "sum_interval requires 'recordtime_end' column in the base variable data"
                    )
                interval_data = (
                    base_data.with_columns(
                        [
                            # Ensure recordtime_end is not null, use recordtime as fallback
                            pl.when(pl.col("recordtime_end").is_null())
                            .then(pl.col("recordtime"))
                            .otherwise(pl.col("recordtime_end"))
                            .alias("recordtime_end_clean"),
                        ]
                    )
                    .with_columns(
                        [
                            # Ensure interval is valid (end >= start), swap if necessary
                            pl.when(
                                pl.col("recordtime_end_clean") < pl.col("recordtime")
                            )
                            .then(
                                pl.col("recordtime")
                            )  # Use recordtime as both start and end if invalid
                            .otherwise(pl.col("recordtime_end_clean"))
                            .alias("recordtime_end_clean"),
                        ]
                    )
                    .with_columns(
                        [
                            # Calculate overlap start (max of interval start and tmin)
                            pl.max_horizontal(["recordtime", "tmin"]).alias(
                                "overlap_start"
                            ),
                            # Calculate overlap end (min of interval end and tmax)
                            pl.min_horizontal(["recordtime_end_clean", "tmax"]).alias(
                                "overlap_end"
                            ),
                        ]
                    )
                )
                # Filter to only intervals that have some overlap
                interval_data = interval_data.filter(
                    pl.col("overlap_start") <= pl.col("overlap_end")
                )
                # Calculate overlap duration and total interval duration
                interval_data = interval_data.with_columns(
                    [
                        # Overlap duration in seconds
                        (pl.col("overlap_end") - pl.col("overlap_start"))
                        .dt.total_seconds()
                        .alias("overlap_seconds"),
                        # Total interval duration in seconds
                        (pl.col("recordtime_end_clean") - pl.col("recordtime"))
                        .dt.total_seconds()
                        .alias("interval_seconds"),
                    ]
                )
                # Calculate proportional values for each selected column
                proportional_exprs = []
                for col in self.select_cols:
                    # Handle division by zero: if interval is instantaneous (0 duration), use full value
                    proportional_expr = (
                        pl.when(pl.col("interval_seconds") <= 0)
                        .then(pl.col(col))  # Full value for instantaneous intervals
                        .otherwise(
                            pl.col(col)
                            * pl.col("overlap_seconds")
                            / pl.col("interval_seconds")
                        )
                        .alias(f"{col}_proportional")
                    )
                    proportional_exprs.append(proportional_expr)
                interval_data = interval_data.with_columns(proportional_exprs)
                # Sum the proportional values by primary key
                sum_exprs = [
                    pl.col(f"{col}_proportional").sum().alias(col)
                    for col in self.select_cols
                ]
                res = interval_data.group_by(cohort.primary_key).agg(sum_exprs)
                # Ensure all cases are included, even those with no overlapping intervals
                # Join with all observations to include cases with zero sum
                res = (
                    obs.select(cohort.primary_key)
                    .unique()
                    .join(res, on=cohort.primary_key, how="left")
                    .with_columns(
                        [pl.col(col).fill_null(0.0) for col in self.select_cols]
                    )
                )
                logger.debug(
                    f"sum_interval processed {interval_data.select(pl.len()).collect().item()} interval records"
                )
            case "closest":
                to_col = params[0]
                tdelta = params[1] if len(params) > 1 else None
                plusminus = params[2] if len(params) > 2 else "52w"
                if " " in plusminus:
                    pm_before, pm_after = plusminus.split(" ")
                else:
                    pm_before = pm_after = plusminus

                # Convert time deltas to total seconds for tolerance
                pm_before_secs = pd.to_timedelta(pm_before).total_seconds()
                pm_after_secs = pd.to_timedelta(pm_after).total_seconds()
                tdelta_secs = pd.to_timedelta(tdelta).total_seconds() if tdelta else 0.0

                # Create target times dataframe
                target_times = obs.select(
                    [
                        cohort.primary_key,
                        (
                            pl.col(to_col).add(pl.duration(seconds=int(tdelta_secs)))
                        ).alias("target_time"),
                    ]
                )

                # Use join_asof for efficient time-based matching
                # For asymmetric tolerances, we need to handle this differently
                if pm_before_secs == pm_after_secs:
                    # Symmetric tolerance - use join_asof directly
                    res = target_times.join_asof(
                        base_data_time.select(
                            [cohort.primary_key, "recordtime"] + self.select_cols
                        ),
                        left_on="target_time",
                        right_on="recordtime",
                        by=cohort.primary_key,
                        tolerance=str(int(pm_before_secs)) + "s",
                        strategy="nearest",
                    ).drop("target_time")
                else:
                    # Asymmetric tolerance - filter then use join_asof
                    # Create time bounds
                    target_times = target_times.with_columns(
                        [
                            (
                                pl.col("target_time").sub(
                                    pl.duration(seconds=int(pm_before_secs))
                                )
                            ).alias("time_min"),
                            (
                                pl.col("target_time").add(
                                    pl.duration(seconds=int(pm_after_secs))
                                )
                            ).alias("time_max"),
                        ]
                    )
                    # Filter data within the asymmetric window using a join
                    filtered_data = base_data_time.join(
                        target_times.select(
                            [cohort.primary_key, "time_min", "time_max", "target_time"]
                        ),
                        on=cohort.primary_key,
                        how="inner",
                    ).filter(
                        (pl.col("recordtime").ge(pl.col("time_min")))
                        & (pl.col("recordtime").le(pl.col("time_max")))
                    )
                    # Now use join_asof to find the closest match within the filtered data
                    res = target_times.join_asof(
                        filtered_data.select(
                            [cohort.primary_key, "recordtime"] + self.select_cols
                        ),
                        left_on="target_time",
                        right_on="recordtime",
                        by=cohort.primary_key,
                        strategy="nearest",
                    ).drop(["target_time", "time_min", "time_max"])
            case _:
                raise ValueError(f"Unsupported aggregation function: {self.agg_func}")

        logger.info(
            f"Extracted {self.var_name}.\nColumns: {res.collect_schema().names()}"
        )
        res = res.select([cohort.primary_key] + self.select_cols)

        # Apply column naming
        # TODO: Remove if this is already handled by Cohort._save_variable
        if len(self.select_cols) > 1:
            rename_map = {col: f"{self.var_name}_{col}" for col in self.select_cols}
            res = res.rename(rename_map)
        else:
            res = res.rename({self.select_cols[0]: self.var_name})

        return res.collect()
""" obs = cohort._obs.clone() for var in self.required_vars.values(): if not var.dynamic: obs = obs.join( var.data.select([cohort.primary_key, var.var_name]), on=cohort.primary_key, how="left", ) if not self.expression: return obs else: expr = eval(self.expression) obs = obs.with_columns(expr.alias(self.var_name)) self.computed_expression = True return obs class DerivedDynamic(BaseVariable): """Derived dynamic variables are extracted using a custom function. Args: var_name: Name of the variable. requires: List of required variables. cleaning: Cleaning parameters ({column_name: {low: int, high: int}}) tmin: Minimum time for the extraction. tmax: Maximum time for the extraction. py: Custom function. py_ready_polars: Whether the custom function is already prepared for polars. Default is False (input and output are pandas dataframes). Examples: >>> def var_func(var, cohort): ... # Note that this simplified example only works if blood_pao2_arterial and vent_fio2 are of the same length (which is proabably not the case). ... return var.with_columns( ... (var.required_vars["blood_pao2_arterial"] / var.required_vars["vent_fio2"]).alias("pf_ratio") ... ) >>> DerivedDynamic( ... var_name="pf_ratio", ... requires=["blood_pao2_arterial", "vent_fio2"], ... py=var_func, ... py_ready_polars=True) """ def __init__( self, var_name, requires: RequirementsIterable, cleaning: CleaningDict | None = None, tmin: TimeBoundColumn | None = None, tmax: TimeBoundColumn | None = None, py: VariableCallable | None = None, py_ready_polars: bool = False, dynamic: bool = True, ): assert ( dynamic is True ), "DerivedDynamic must be dynamic, please verify the configuration." super().__init__( var_name, dynamic=True, requires=requires, tmin=tmin, tmax=tmax, py=py, py_ready_polars=py_ready_polars, cleaning=cleaning, ) def extract(self, cohort: "Cohort") -> pl.DataFrame: """Extract the variable. Args: cohort (Cohort): Cohort object. Returns: pd.DataFrame: Extracted variable. """ assert ( self.tmin is not None ), "tmin must be provided for all variables upon extraction" assert ( self.tmax is not None ), "tmax must be provided for all variables upon extraction" self._get_required_vars(cohort) self._call_var_function(cohort, convert_polars=True) self._apply_cleaning() self._add_relative_times(cohort) self._unify_and_order_columns(cohort.primary_key) return self.data