from __future__ import annotations
import os
import pickle
import pandas as pd
import polars as pl
from corr_vars import logger
from corr_vars.core.variable import NativeVariable
from corr_vars.core.variable import Variable as BaseVariable
from corr_vars.sources.var_loader import load_variable, MultiSourceVariable
from corr_vars.sources.cub_hdp import helpers
import corr_vars.utils as utils
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from corr_vars.core import Cohort
from corr_vars.definitions import (
TimeBoundColumn,
RequirementsIterable,
CleaningDict,
VariableCallable,
)
def Variable(*args, **kwargs) -> BaseVariable:
"""Factory function to create the correct variable type based on the arguments."""
logger.info("Creating variable with:")
if args:
prefix = "└── " if not kwargs else "├── "
logger.info(f"{prefix}args: {args}")
if kwargs:
prefix = "└── "
logger.info(f"{prefix}kwargs:")
utils.log_dict(logger=logger, dictionary=kwargs, indent=4)
    # Handle legacy naming: "variable" was renamed to "base_var"
    if kwargs.get("variable"):
        kwargs["base_var"] = kwargs.pop("variable")
    var_type = kwargs.pop("type", None)
    # pop() with a default never raises, so check for a missing type explicitly
    if var_type is None:
        logger.error("CubHDP Variable requires the 'type' field.")
    # Dynamic variable types imply dynamic=True
    if var_type in [
        "native_dynamic",
        "derived_dynamic",
    ]:
        kwargs["dynamic"] = True
    # Complex variables skip the default database extraction
    if var_type == "complex":
        kwargs["complex"] = True
if var_type in ["native_dynamic", "complex"]:
return NativeDynamic(*args, **kwargs)
elif var_type == "derived_dynamic":
return DerivedDynamic(*args, **kwargs)
elif var_type == "native_static":
return NativeStatic(*args, **kwargs)
elif var_type == "derived_static":
return DerivedStatic(*args, **kwargs)
else:
raise KeyError(f"Type {var_type} does not exist.")
class NativeDynamic(NativeVariable):
"""
    Native dynamic variable class for the cub_hdp source, extracted directly from the source database.
Args:
var_name: Name of the variable
dynamic: Whether the variable is dynamic
requires: List of required variables
tmin: Minimum time point
tmax: Maximum time point
py: Python function to apply to the variable
        py_ready_polars: Whether the py function is already prepared for polars
        cleaning: Cleaning parameters ({column_name: {low: int, high: int}})
allow_caching: Whether to allow caching
table: Name of the table to extract from
where: WHERE clause to filter the table
value_dtype: Data type of the value column
columns: List of columns to extract
        complex: Whether the variable is complex (skips the default database extraction, so the query must be issued inside the custom py function)
Returns:
None
Note:
Only use this class for defining custom variables. For CORR-specified variables, use cohort.add_variable() directly.
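
    Example (illustrative sketch; the table, WHERE clause, and dtype are hypothetical
    and depend on the source schema):
        >>> var = NativeDynamic(
        ...     var_name="blood_lactate",
        ...     dynamic=True,
        ...     table="lab_results",
        ...     where="analyte = 'LACT'",
        ...     value_dtype="float",
        ... )
        >>> var.extract(cohort)  # Returns a polars DataFrame; also kept on var.data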
"""
def __init__(
self,
# core.variable.Variable
var_name: str,
dynamic: bool,
requires: RequirementsIterable = [],
tmin: TimeBoundColumn | None = None,
tmax: TimeBoundColumn | None = None,
py: VariableCallable | None = None,
py_ready_polars: bool = False,
cleaning: CleaningDict | None = None,
# core.variable.NativeVariable(Variable)
allow_caching: bool = True,
# sources.cub_hdp.extract.Variable(NativeVariable)
table: str | None = None,
where: str | None = None,
value_dtype: str | None = None,
columns: dict[str, str | dict[str, str]] | None = None,
complex: bool = False,
) -> None:
super().__init__(
var_name=var_name,
dynamic=dynamic,
requires=requires,
tmin=tmin,
tmax=tmax,
py=py,
py_ready_polars=py_ready_polars,
cleaning=cleaning,
allow_caching=allow_caching,
)
self.table = table
self.where = where
self.value_dtype = value_dtype
self.columns = columns
self.complex = complex
# Apply timefilter always for non-complex variables
self._timefilter_always = not self.complex
def extract(self, cohort: Cohort) -> pl.DataFrame:
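        """Extract the variable for the given cohort.

        Called internally when the variable is added to a cohort. Returns the
        extracted data as a polars DataFrame, which is also stored on ``self.data``.
        """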
assert self.tmin is not None, "tmin must be provided for all variables"
assert self.tmax is not None, "tmax must be provided for all variables"
self.data = self._get_from_cache(cohort)
if self.data is None:
self._get_required_vars(cohort)
            # Non-complex variables are extracted directly from the database
if not self.complex:
self.data = self._extract_from_db(cohort)
function_called = self._call_var_function(cohort, convert_polars=True)
if self.data is None:
if self.complex and not function_called:
raise ValueError(
"Complex variables need to extract data inside a custom variable function."
)
else:
raise ValueError("Variable extraction resulted in no data.")
self.data = self.build_attributes(self.data)
self._write_to_cache(cohort)
if self.dynamic:
            # Complex variables are time-filtered only if case_tmin and case_tmax are prepared in the variable function
self._timefilter(cohort, always=self._timefilter_always)
self._apply_cleaning()
self._add_relative_times(cohort)
self._unify_and_order_columns(cohort.primary_key)
return self.data
def _extract_from_db(self, cohort: Cohort) -> pl.DataFrame:
"""Extract data from the database."""
if self.complex:
raise RuntimeError(
"Complex NativeDynamic should not call _extract_from_db."
)
with helpers.temporary_table(cohort) as temp_tbl:
query, table_info, keep_cols = helpers.build_base_query(
self, helpers._get_cub_hdp_database(cohort), temp_tbl
)
conn = helpers.reinit_conn(cohort)
result_df: pl.LazyFrame = conn.lf_read_sql(
query=query,
datetime_cols=[
col
for col in table_info
if "recordtime" in col or "storetime" in col
],
keep_cols=keep_cols,
)
logger.info(f"Extracted {result_df.select(pl.len()).collect().item()} rows")
result_df = self.build_attributes(result_df)
return result_df.collect()
def _get_from_cache(self, cohort: Cohort) -> pl.DataFrame | None:
path = cohort.tmpdir_manager.create_tmpdir_variable_path(
var_name=self.var_name, extension=None
)
if self.allow_caching and os.path.exists(f"{path}.parquet"):
# Before loading, ensure that cacheinfo exists and all case_ids from cohort.obs are present
if not os.path.exists(f"{path}.cacheinfo"):
return None
with open(f"{path}.cacheinfo", "rb") as f:
cache_case_ids = pickle.load(f)
cohort_case_ids = set(
cohort.obs.select("case_id").unique().to_series().to_list()
)
            # Return None if cohort.obs contains uncached case_ids
if not cohort_case_ids.issubset(cache_case_ids):
return None
self.data = pl.read_parquet(f"{path}.parquet")
logger.info(
f"Loaded {self.var_name} from tmpdir ({self.data.shape[0]} rows)"
)
if "case_id" in self.data.columns:
self.data = self.data.filter(pl.col("case_id").is_in(cohort_case_ids))
logger.info(f"Found {self.data.shape[0]} rows matching cohort.obs")
if "case_tmin" in self.data.columns and "case_tmax" in self.data.columns:
                # Complex variables may have precomputed case_tmin and case_tmax columns.
                # The timefilter must then be reapplied, so set _timefilter_always to True
                # and drop case_tmin, case_tmax, and any non-case_id primary key column.
drop_cols = ["case_tmin", "case_tmax"] + (
[cohort.primary_key] if cohort.primary_key != "case_id" else []
)
self.data = self.data.drop(drop_cols, strict=False)
self._timefilter_always = True
return self.data
return None
def _write_to_cache(self, cohort: Cohort) -> None:
if self.allow_caching and self.data is not None:
path = cohort.tmpdir_manager.create_tmpdir_variable_path(
var_name=self.var_name, extension=None
)
self.data.write_parquet(f"{path}.parquet")
case_ids = cohort.obs.select("case_id").unique().to_series().to_list()
if os.path.exists(f"{path}.cacheinfo"):
os.remove(f"{path}.cacheinfo")
# Write cacheinfo
with open(f"{path}.cacheinfo", "wb") as f:
pickle.dump(case_ids, f)
logger.info(f"Wrote {self.var_name} to tmpdir ({self.data.shape[0]} rows)")
class NativeStatic(BaseVariable):
"""Aggregated variables represent simple aggregations of dynamic variables.
Args:
var_name (str): Name of the variable.
select (str): Select clause specifying aggregation function and columns.
base_var (str): Name of the base variable (must be a native_dynamic variable).
where (str, optional): Optional WHERE clause (in format for polars).
tmin (str, optional): Minimum time for the extraction.
tmax (str, optional): Maximum time for the extraction.
The select argument supports several aggregation functions:
- ``!first [columns]``: Returns the first row within this case
>>> "!first value" # Single column
>>> "!first value, recordtime" # Multiple columns
- ``!last [columns]``: Returns the last row within this case
>>> "!last value"
>>> "!last value, recordtime"
- ``!any``: Returns True if any value exists
>>> "!any"
>>> "!any value"
- ``!sum [column]``: Calculates sum of values
>>> "!sum value"
- ``!closest(to_column, timedelta, plusminus) [columns]``: Selects value closest to specified column
Args:
to_column: Column to compare "recordtime" against
timedelta: Time to add to "to_column" for comparison
plusminus: Allowed time mismatch (can specify different before/after with space)
>>> "!closest(hospital_admission) value, recordtime" # Closest to admission
>>> "!closest(hospital_admission, 0, 2h 3h) value" # 2h before to 3h after
>>> "!closest(first_intubation_dtime, 6h, 2h) value" # 6h after intubation ±2h
- ``!mean [column]``: Calculates mean value
>>> "!mean value"
- ``!median [column]``: Calculates median value
>>> "!median value"
- ``!perc(quantile) [column]``: Calculates specified percentile
>>> "!perc(75) value" # 75th percentile
    The where argument supports pandas-style boolean expressions; these are converted to polars expressions and evaluated against the base variable's data.
Where also supports magic commands (starting with !) to filter the data. Supported commands are:
- ``!isin(column, [values])``: Filters rows where the value in column is in values
- ``!startswith(column, [values])``: Filters rows where the value in column starts with any of the values
- ``!endswith(column, [values])``: Filters rows where the value in column ends with any of the values
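
    Example (illustrative sketch; the base variable, unit value, and attribute column are hypothetical):
        >>> NativeStatic(
        ...     var_name="mean_sodium",
        ...     select="!mean value",
        ...     base_var="blood_sodium",
        ...     where="!isin(attributes.unit, ['mmol/l'])",
        ...     tmin="hospital_admission",
        ... )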
"""
base_var: BaseVariable | MultiSourceVariable
def __init__(
self,
var_name: str,
select: str,
base_var: str | BaseVariable,
where: str | None = None,
tmin: TimeBoundColumn | None = None,
tmax: TimeBoundColumn | None = None,
cleaning: CleaningDict | None = None,
) -> None:
super().__init__(
var_name,
dynamic=False,
tmin=tmin,
tmax=tmax,
cleaning=cleaning,
)
self.select = select
logger.debug(
f"NativeStatic: SELECT {self.select} FROM {base_var} WHERE {where}"
)
self.where = where
self._base_var = base_var
self.agg_func, self.agg_params, self.select_cols = helpers.parse_select(
self.select
)
if len(self.select_cols) == 0:
self.select_cols = ["value"]
def extract(self, cohort: Cohort) -> pl.DataFrame:
"""Extract the variable. You do not need to call this yourself, as it is called internally when you add the variable to a cohort.
However, you may call it directly to obtain variable data independently of the cohort. You still need a cohort object for case ids and other metadata.
Args:
cohort: Cohort object.
Returns:
Extracted variable.
After extraction, you may also access the data as ``Variable.data``.
Examples:
>>> var = NativeStatic(
... var_name="first_sodium_recordtime",
... select="!first recordtime",
... base_var="blood_sodium",
... tmin="hospital_admission"
... )
>>> var.extract(cohort) # With var.extract(), the data will not be added to the cohort.
>>> var.data # You can access the data as a polars dataframe.
"""
assert (
self.tmin is not None
), "tmin must be provided for all variables upon extraction"
assert (
self.tmax is not None
), "tmax must be provided for all variables upon extraction"
if isinstance(self._base_var, str):
self.base_var = load_variable(
self._base_var, cohort=cohort, tmin=self.tmin, tmax=self.tmax
)
else:
self._base_var.tmin = self.tmin
self._base_var.tmax = self.tmax
self.base_var = self._base_var
self.base_data = self.base_var.extract(cohort=cohort)
self.data = self._extract_aggregation(cohort)
self._apply_cleaning()
self._call_var_function(cohort)
return self.data
def _extract_aggregation(self, cohort: Cohort) -> pl.DataFrame:
"""Extract static variable using the base variable
Args:
cohort (Cohort): Cohort object.
Returns:
pl.DataFrame: Extracted variable.
"""
base_data = self.base_data.lazy()
obs = cohort._obs.lazy()
# Sort by recordtime
base_data = base_data.sort("recordtime")
# Handle attributes if present
if "attributes" in base_data.collect_schema().names():
# Flatten JSON attributes into separate columns
attrs_sample = base_data.select("attributes").limit(1).collect()
if attrs_sample.height > 0 and attrs_sample["attributes"][0] is not None:
# Get the first non-null attribute to determine structure
sample_attrs = attrs_sample["attributes"][0]
if isinstance(sample_attrs, dict) and sample_attrs:
# Create expressions to extract each attribute
attr_exprs = [
pl.col("attributes")
.struct.field(key)
.alias(f"attributes.{key}")
for key in sample_attrs.keys()
]
base_data = base_data.with_columns(attr_exprs).drop("attributes")
# Apply where clause filtering
if self.where:
if self.where.startswith("!"):
where_func, where_params, _ = helpers.parse_select(self.where)
match where_func:
case "isin":
base_data = base_data.filter(
pl.col(where_params[0]).is_in(where_params[1])
)
case "startswith":
# Create OR expression for multiple prefixes
if len(where_params[1]) == 1:
base_data = base_data.filter(
pl.col(where_params[0]).str.starts_with(
where_params[1][0]
)
)
else:
starts_expr = pl.col(where_params[0]).str.starts_with(
where_params[1][0]
)
for prefix in where_params[1][1:]:
starts_expr = starts_expr | pl.col(
where_params[0]
).str.starts_with(prefix)
base_data = base_data.filter(starts_expr)
case "endswith":
# Create OR expression for multiple suffixes
if len(where_params[1]) == 1:
base_data = base_data.filter(
pl.col(where_params[0]).str.ends_with(
where_params[1][0]
)
)
else:
ends_expr = pl.col(where_params[0]).str.ends_with(
where_params[1][0]
)
for suffix in where_params[1][1:]:
ends_expr = ends_expr | pl.col(
where_params[0]
).str.ends_with(suffix)
base_data = base_data.filter(ends_expr)
case _:
raise ValueError(
f"Unsupported where function: {where_func}. Supported functions are: isin, startswith, endswith"
)
else:
# Convert pandas-style where clause to polars expression
polars_where = helpers.sql_to_polars(self.where)
base_data = base_data.filter(eval(polars_where))
# Parse time arguments
obs = utils.pl_parse_time_args(obs, tmin=self.tmin, tmax=self.tmax)
# Join with observation times
base_data = base_data.join(
obs.select([cohort.primary_key, "tmin", "tmax"]),
on=cohort.primary_key,
how="left",
)
logger.debug(
f"Base data shape (before filtering): {base_data.select(pl.len()).collect().item()}"
)
base_data_time = base_data.filter(
pl.col("recordtime").is_between("tmin", "tmax")
)
logger.debug(
f"Base data shape (after filtering tmin and tmax): {base_data_time.select(pl.len()).collect().item()}"
)
params = self.agg_params
# Perform aggregation using Polars syntax
match self.agg_func.lower():
case "first":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).first()
)
case "last":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).last()
)
case "mean":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).mean()
)
case "median":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).median()
)
case "min":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).min()
)
case "max":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).max()
)
case "std":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).std()
)
case "perc":
quantile = float(params[0]) / 100
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).quantile(quantile)
)
case "sum":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.col(self.select_cols).sum()
)
case "count":
res = base_data_time.group_by(cohort.primary_key).agg(
pl.len().alias(self.select_cols[0])
)
case "any":
res = base_data_time.group_by(cohort.primary_key).agg(
                    pl.col(col).is_not_null().any().alias(col)
for col in self.select_cols
)
# Right join with all observations to include cases with no data
res = (
obs.select(cohort.primary_key)
.unique()
.join(res, on=cohort.primary_key, how="left")
.with_columns(
[pl.col(col).fill_null(False) for col in self.select_cols]
)
)
case "sum_interval":
# Use unfiltered base_data instead of base_data_time here
if "recordtime_end" not in base_data.columns:
raise ValueError(
"sum_interval requires 'recordtime_end' column in the base variable data"
)
interval_data = (
base_data.with_columns(
[
# Ensure recordtime_end is not null, use recordtime as fallback
pl.when(pl.col("recordtime_end").is_null())
.then(pl.col("recordtime"))
.otherwise(pl.col("recordtime_end"))
.alias("recordtime_end_clean"),
]
)
.with_columns(
[
# Ensure interval is valid (end >= start), swap if necessary
pl.when(
pl.col("recordtime_end_clean") < pl.col("recordtime")
)
.then(
pl.col("recordtime")
) # Use recordtime as both start and end if invalid
.otherwise(pl.col("recordtime_end_clean"))
.alias("recordtime_end_clean"),
]
)
.with_columns(
[
# Calculate overlap start (max of interval start and tmin)
pl.max_horizontal(["recordtime", "tmin"]).alias(
"overlap_start"
),
# Calculate overlap end (min of interval end and tmax)
pl.min_horizontal(["recordtime_end_clean", "tmax"]).alias(
"overlap_end"
),
]
)
)
# Filter to only intervals that have some overlap
interval_data = interval_data.filter(
pl.col("overlap_start") <= pl.col("overlap_end")
)
# Calculate overlap duration and total interval duration
interval_data = interval_data.with_columns(
[
# Overlap duration in seconds
(pl.col("overlap_end") - pl.col("overlap_start"))
.dt.total_seconds()
.alias("overlap_seconds"),
# Total interval duration in seconds
(pl.col("recordtime_end_clean") - pl.col("recordtime"))
.dt.total_seconds()
.alias("interval_seconds"),
]
)
# Calculate proportional values for each selected column
proportional_exprs = []
for col in self.select_cols:
# Handle division by zero: if interval is instantaneous (0 duration), use full value
proportional_expr = (
pl.when(pl.col("interval_seconds") <= 0)
.then(pl.col(col)) # Full value for instantaneous intervals
.otherwise(
pl.col(col)
* pl.col("overlap_seconds")
/ pl.col("interval_seconds")
)
.alias(f"{col}_proportional")
)
proportional_exprs.append(proportional_expr)
interval_data = interval_data.with_columns(proportional_exprs)
# Sum the proportional values by primary key
sum_exprs = [
pl.col(f"{col}_proportional").sum().alias(col)
for col in self.select_cols
]
res = interval_data.group_by(cohort.primary_key).agg(sum_exprs)
# Ensure all cases are included, even those with no overlapping intervals
# Join with all observations to include cases with zero sum
res = (
obs.select(cohort.primary_key)
.unique()
.join(res, on=cohort.primary_key, how="left")
.with_columns(
[pl.col(col).fill_null(0.0) for col in self.select_cols]
)
)
logger.debug(
f"sum_interval processed {interval_data.select(pl.len()).collect().item()} interval records"
)
case "closest":
to_col = params[0]
tdelta = params[1] if len(params) > 1 else None
plusminus = params[2] if len(params) > 2 else "52w"
if " " in plusminus:
pm_before, pm_after = plusminus.split(" ")
else:
pm_before = pm_after = plusminus
# Convert time deltas to total seconds for tolerance
pm_before_secs = pd.to_timedelta(pm_before).total_seconds()
pm_after_secs = pd.to_timedelta(pm_after).total_seconds()
tdelta_secs = pd.to_timedelta(tdelta).total_seconds() if tdelta else 0.0
# Create target times dataframe
target_times = obs.select(
[
cohort.primary_key,
(
pl.col(to_col).add(pl.duration(seconds=int(tdelta_secs)))
).alias("target_time"),
]
)
                # Use join_asof for efficient time-based matching
                # For asymmetric tolerances, pre-filter to the allowed window first
if pm_before_secs == pm_after_secs:
                    # Symmetric tolerance - use join_asof directly with a tolerance
res = target_times.join_asof(
base_data_time.select(
[cohort.primary_key, "recordtime"] + self.select_cols
),
left_on="target_time",
right_on="recordtime",
by=cohort.primary_key,
tolerance=str(int(pm_before_secs)) + "s",
strategy="nearest",
).drop("target_time")
else:
                    # Asymmetric tolerance - filter to the window, then use join_asof
# Create time bounds
target_times = target_times.with_columns(
[
(
pl.col("target_time").sub(
pl.duration(seconds=int(pm_before_secs))
)
).alias("time_min"),
(
pl.col("target_time").add(
pl.duration(seconds=int(pm_after_secs))
)
).alias("time_max"),
]
)
# Filter data within the asymmetric window using a join
filtered_data = base_data_time.join(
target_times.select(
[cohort.primary_key, "time_min", "time_max", "target_time"]
),
on=cohort.primary_key,
how="inner",
).filter(
(pl.col("recordtime").ge(pl.col("time_min")))
& (pl.col("recordtime").le(pl.col("time_max")))
)
                    # Now use join_asof to find the closest match within the filtered data
res = target_times.join_asof(
filtered_data.select(
[cohort.primary_key, "recordtime"] + self.select_cols
),
left_on="target_time",
right_on="recordtime",
by=cohort.primary_key,
strategy="nearest",
).drop(["target_time", "time_min", "time_max"])
case _:
raise ValueError(f"Unsupported aggregation function: {self.agg_func}")
logger.info(
f"Extracted {self.var_name}.\nColumns: {res.collect_schema().names()}"
)
res = res.select([cohort.primary_key] + self.select_cols)
# Apply column naming
# TODO: Remove if this is already handled by Cohort._save_variable
if len(self.select_cols) > 1:
rename_map = {col: f"{self.var_name}_{col}" for col in self.select_cols}
res = res.rename(rename_map)
else:
res = res.rename({self.select_cols[0]: self.var_name})
return res.collect()
class DerivedStatic(BaseVariable):
"""DerivedStatic: These variables are derivations on existing columns in the cohort.obs dataframe based on the expression argument.
Args:
var_name: Name of the variable.
requires: List of required variables.
expression: Expression to extract the variable.
tmin: Minimum time for the extraction.
tmax: Maximum time for the extraction.
Note that DerivedStatic variables are executed on the cohort.obs dataframe and must reference existing columns in cohort.obs.
For DerivedStatic variables, you may either provide an expression or a custom function in variables.py.
Use expressions where possible, but custom functions if you require more complex logic.
Examples:
>>> DerivedStatic(
... var_name="inhospital_death",
... requires=["hospital_discharge", "death_timestamp"],
... expression="hospital_discharge <= death_timestamp"
... )
>>> DerivedStatic(
... var_name="any_va_ecmo_icu",
... requires=["ecmo_va_icu_ops", "ecmo_va_icu"],
        ...     expression="ecmo_va_icu_ops | ecmo_va_icu"
... )
"""
def __init__(
self,
var_name,
requires: RequirementsIterable = [],
expression: str | None = None,
tmin: TimeBoundColumn | None = None,
tmax: TimeBoundColumn | None = None,
py: VariableCallable | None = None,
py_ready_polars: bool = False,
dynamic: bool = False,
cleaning: CleaningDict | None = None,
):
assert (
dynamic is False
), "DerivedStatic cannot be dynamic, please verify the configuration."
super().__init__(
var_name=var_name,
dynamic=False,
requires=requires,
tmin=tmin,
tmax=tmax,
py=py,
py_ready_polars=py_ready_polars,
cleaning=cleaning,
)
self.expression = expression
self.computed_expression = False
def extract(self, cohort: "Cohort") -> pd.DataFrame:
assert self.tmin is not None, "tmin must be provided upon extraction"
assert self.tmax is not None, "tmax must be provided upon extraction"
self._get_required_vars(cohort)
self.data = self._compute_expression(cohort)
called_var_function = self._call_var_function(cohort, convert_polars=True)
assert (
self.computed_expression or called_var_function
), "No expression or variable function found."
assert (
self.var_name in self.data.columns
), f"DerivedStatic variable {self.var_name} not found in extracted data. ({self.data.columns})"
# Get all columns that are either the primary key, the variable name itself, or start with the variable name
cols_to_keep = [
col
for col in self.data.columns
if col == cohort.primary_key
or col == self.var_name
or (isinstance(col, str) and col.startswith(f"{self.var_name}_"))
]
self.data = self.data.select(cols_to_keep)
self._apply_cleaning()
return self.data
def _compute_expression(self, cohort):
"""Compute the expression using the required variables.
Args:
cohort (Cohort): Cohort object.
Returns:
pl.DataFrame: Extracted variable. (Returns the original obs dataframe if no expression is provided.)
"""
obs = cohort._obs.clone()
for var in self.required_vars.values():
if not var.dynamic:
obs = obs.join(
var.data.select([cohort.primary_key, var.var_name]),
on=cohort.primary_key,
how="left",
)
if not self.expression:
return obs
else:
expr = eval(self.expression)
obs = obs.with_columns(expr.alias(self.var_name))
self.computed_expression = True
return obs
class DerivedDynamic(BaseVariable):
"""Derived dynamic variables are extracted using a custom function.
Args:
var_name: Name of the variable.
requires: List of required variables.
cleaning: Cleaning parameters ({column_name: {low: int, high: int}})
tmin: Minimum time for the extraction.
tmax: Maximum time for the extraction.
py: Custom function.
py_ready_polars: Whether the custom function is already prepared for polars. Default is False (input and output are pandas dataframes).
Examples:
>>> def var_func(var, cohort):
        ...     # Note that this simplified example only works if blood_pao2_arterial and vent_fio2 have the same length (which is probably not the case).
... return var.with_columns(
... (var.required_vars["blood_pao2_arterial"] / var.required_vars["vent_fio2"]).alias("pf_ratio")
... )
>>> DerivedDynamic(
... var_name="pf_ratio",
... requires=["blood_pao2_arterial", "vent_fio2"],
... py=var_func,
... py_ready_polars=True)
"""
def __init__(
self,
var_name,
requires: RequirementsIterable,
cleaning: CleaningDict | None = None,
tmin: TimeBoundColumn | None = None,
tmax: TimeBoundColumn | None = None,
py: VariableCallable | None = None,
py_ready_polars: bool = False,
dynamic: bool = True,
):
assert (
dynamic is True
), "DerivedDynamic must be dynamic, please verify the configuration."
super().__init__(
var_name,
dynamic=True,
requires=requires,
tmin=tmin,
tmax=tmax,
py=py,
py_ready_polars=py_ready_polars,
cleaning=cleaning,
)
def extract(self, cohort: "Cohort") -> pl.DataFrame:
"""Extract the variable.
Args:
cohort (Cohort): Cohort object.
Returns:
            pl.DataFrame: Extracted variable.
"""
assert (
self.tmin is not None
), "tmin must be provided for all variables upon extraction"
assert (
self.tmax is not None
), "tmax must be provided for all variables upon extraction"
self._get_required_vars(cohort)
self._call_var_function(cohort, convert_polars=True)
self._apply_cleaning()
self._add_relative_times(cohort)
self._unify_and_order_columns(cohort.primary_key)
return self.data