from __future__ import annotations
from datetime import datetime, timezone
import os
import pickle
import tarfile
import tempfile
import textwrap
import traceback
import warnings
import gc
import pandas as pd
import polars as pl
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
import zstandard as zstd
from corr_vars import logger
from corr_vars.core.change_tracker import ChangeTrackerPipeline, ChangeTrackerContext
from corr_vars.core.file_manager import TemporaryDirectoryManager
from corr_vars.core.variable import Variable
import corr_vars.sources as loader
from corr_vars.sources.var_loader import MultiSourceVariable
import corr_vars.utils as utils
from collections.abc import Iterator, MutableMapping, Hashable
from corr_vars.definitions import ObsLevel
from typing import Final, Literal, Any, overload, TYPE_CHECKING
if TYPE_CHECKING:
from corr_vars.definitions import (
ArrayLike,
ObsLevelNames,
SourceDictPartial,
SourceDict,
TimeBoundColumn,
)
from polars._typing import (
SingleIndexSelector,
SingleColSelector,
MultiColSelector,
MultiIndexSelector,
)
from tableone import TableOne
class CohortDataError(Exception):
"""Raised upon cohort data errors."""
pass
class VariableNotFoundError(Exception):
"""Raised when a variable is not found in the cohort."""
pass
class ObsmDict(MutableMapping[str, pl.DataFrame]):
"""Dictionary-like wrapper for obsm data providing convenient access and display.
Args:
data (dict): Dictionary of polars DataFrames containing dynamic variable data.
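Examples:
Illustrative usage (assumes a variable named "blood_sodium" has already been extracted):
>>> cohort.obsm["blood_sodium"]       # DataFrame with the extracted time series
>>> "blood_sodium" in cohort.obsm     # membership test
>>> list(cohort.obsm)                 # names of all extracted dynamic variables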
"""
def __init__(self, data: dict[str, pl.DataFrame]) -> None:
self._data = data
def __getitem__(self, key) -> pl.DataFrame:
try:
return self._data[key]
except KeyError:
available_vars = list(self._data.keys())
if available_vars:
error_msg = (
f"Variable '{key}' not found in cohort.obsm. "
f"Available variables: {', '.join(available_vars)}"
)
else:
error_msg = (
f"Variable '{key}' not found in cohort.obsm. "
f"No variables have been extracted yet. Use cohort.add_variable() to extract variables."
)
raise VariableNotFoundError(error_msg) from None
def __contains__(self, key):
return key in self._data
def __setitem__(self, key, value) -> None:
self._data[key] = value
def __delitem__(self, key) -> None:
del self._data[key]
def __iter__(self) -> Iterator[str]:
return iter(self._data)
def __len__(self) -> int:
return len(self._data)
def __repr__(self) -> str:
lines = [f"{k}: shape={df.shape}" for k, df in self._data.items()]
return "ObsmDict(\n " + "\n ".join(lines) + "\n)" if lines else "ObsmDict(Ø)"
class Cohort:
"""Class to build a cohort in the CORR database.
Args:
obs_level (Literal["icu_stay", "hospital_stay", "procedure"]):
Observation level (default: "icu_stay").
sources (dict[str, dict]):
Dictionary of data sources to use for data extraction. Available options are "cub_hdp", "reprodicu", "co5".
Source configurations:
- cub_hdp: Accepted keys: database, password_file, merge_consecutive, filters, conn_args [remote_hostname (str), username (str)]
- reprodicu: Accepted keys: path, exclude_datasets
- co5: Not implemented yet. Data request is still pending.
Note: reprodicu does not yet implement variable extraction, only cohort data.
project_vars (dict):
Dictionary with local variable definitions (default: {}).
load_default_vars (bool):
Whether to load the default variables (default: True).
logger_args (dict):
Dictionary of logging configurations [level (int), file_path (str), file_mode (str), verbose_fmt (bool), colored_output (bool), formatted_numbers (bool)] (default: {}).
Attributes:
obs (pl.DataFrame):
Static data for each observation. Contains one row per observation (e.g., ICU stay)
with columns for static variables like demographics and outcomes.
Example:
>>> cohort.obs
patient_id case_id icu_stay_id icu_admission icu_discharge sex ... inhospital_death
0 P001 C001 C001_1 2023-01-01 08:30:00 2023-01-03 12:00:00 M ... False
1 P001 C001 C001_2 2023-01-03 14:20:00 2023-01-05 16:30:00 M ... False
2 P002 C002 C002_1 2023-01-02 09:15:00 2023-01-04 10:30:00 F ... False
3 P003 C003 C003_1 2023-01-04 11:45:00 2023-01-07 13:20:00 F ... True
...
obsm (dict of pl.DataFrame):
Dynamic data stored as dictionary of DataFrames. Each DataFrame contains time-series
data for a variable with columns:
- recordtime: Timestamp of the measurement
- value: Value of the measurement
- recordtime_end: End time (only for duration-based variables like therapies)
- description: Additional information (e.g., medication names)
Example:
>>> cohort.obsm["blood_sodium"]
icu_stay_id recordtime value
0 C001_1 2023-01-01 09:30:00 138
1 C001_1 2023-01-02 10:15:00 141
2 C001_2 2023-01-03 15:00:00 137
3 C002_1 2023-01-02 10:00:00 142
4 C003_1 2023-01-04 12:30:00 139
...
Notes:
- For large cohorts, set ``load_default_vars=False`` to speed up the extraction. You can use
pre-extracted cohorts as starting points and load them using ``Cohort.load()``.
- Variables can be added using ``cohort.add_variable()``. Static variables will be added to ``obs``,
dynamic variables to ``obsm``.
- ``sources["cub_hdp"]["filters"]`` also allows a special shorthand "_dx" to extract the hospital admissions of the
last x months, useful for debugging/prototyping. For example, use "_d2" to extract every admission
of the last 2 months.
Examples:
Create a new cohort:
>>> cohort = Cohort(
... obs_level="icu_stay",
... load_default_vars=False,
... sources={
... "cub_hdp": {
... "database": "db_hypercapnia_prepared",
... "password_file": True,
... "merge_consecutive": False},
... "reprodicu": {
... "path": "/data02/projects/reprodicubility/reprodICU/reprodICU_files"}
... })
Access static data:
>>> cohort.obs.select("age_on_admission") # Get age for all patients
>>> cohort.obs.filter(pl.col("sex") == "M") # Filter for male patients
Access time-series data:
>>> cohort.obsm["blood_sodium"] # Get all blood sodium measurements
>>> # Get blood sodium measurements for a specific observation
>>> cohort.obsm["blood_sodium"].filter(pl.col(cohort.primary_key) == "12345")
"""
# Data variables
_obs: pl.DataFrame
_obsm: dict[str, pl.DataFrame]
constant_vars: list[str]
# Observation level variables
obs_level: ObsLevel
primary_key: str
t_min: str
t_max: str
t_eligible: str
t_outcome: str
# Changetracker
_change_tracker: Final[ChangeTrackerPipeline]
# Configurations
logger_args: dict[str, Any]
sources: Final[SourceDict]
project_vars: dict[str, dict[str, Any]]
# Temporary directory variables
tmpdir_manager: Final[TemporaryDirectoryManager]
_current_tmpdir_path: str
# Flags
_from_file: bool
_data_load_time: datetime
# Properties (defined at the bottom of Cohort class)
# obs
# tmpdir_path
# stata
# tableone
def __init__(
self,
obs_level: ObsLevelNames = "icu_stay",
sources: SourceDictPartial = {
"cub_hdp": {
"database": "db_hypercapnia_prepared",
"icu_stay": {
"merge_consecutive": True,
},
"conn_args": {
"password_file": True,
},
}
},
project_vars: dict[str, dict[str, Any]] = {},
load_default_vars: bool = True,
logger_args: dict[str, Any] = {},
) -> None:
# Set up the logger FIRST
self._setup_logger(logger_args)
# Add flag to indicate this is a newly created cohort
# Used in unit tests
self._from_file = False
self.sources = self._load_default_config_data(sources)
self.project_vars = project_vars
# Create temporary directory for storing intermediary data
self.tmpdir_manager = TemporaryDirectoryManager()
# Load data and set primary keys for variables
self._set_obs_level(obs_level)
self._load_obs_level_keys() # Primary keys, tmin, tmax, etc.
self._load_obs_level_data() # Static / constant data
self._obsm = {} # Time series data
# Load default variables
if load_default_vars:
self.load_default_vars()
# Create Change Tracker object - Used for inclusion and exclusion
self._change_tracker = ChangeTrackerPipeline(
primary_key=self.primary_key,
initial_df=self._obs,
initial_description="Observations in database",
)
logger.info(
f"SUCCESS: Extracted data. Cohort has {len(self)} observations ({self.obs_level.lower_name})."
)
def _setup_logger(self, logger_args: dict[str, Any]) -> None:
"""Changes logger configuration."""
self.logger_args = logger_args
utils.configure_logger_level_and_handlers(logger=logger, **self.logger_args)
def _load_default_config_data(
self,
sources: SourceDictPartial,
) -> SourceDict:
return loader.config_loader.load_default_config_data(sources=sources)
def _set_obs_level(self, obs_level: str) -> None:
"""Set the observation level."""
try:
# Calling ObsLevel with a str is not an error
# ObsLevel._missing_ will find the correct ObsLevel
self.obs_level = ObsLevel(obs_level) # type: ignore[call-arg]
except ValueError:
raise ValueError(f"Observation level {obs_level} not supported.")
def _load_obs_level_keys(self) -> None:
"""Mirrors ObsLevel keys on Cohort object."""
self.primary_key = self.obs_level.primary_key
self.t_min = self.obs_level.t_min
self.t_max = self.obs_level.t_max
self.t_eligible = self.obs_level.t_eligible
self.t_outcome = self.obs_level.t_outcome
def _load_obs_level_data(self) -> None:
"""Load the static data for the observation level."""
# Record the load timestamp for reproducibility
self._data_load_time = datetime.now(timezone.utc)
# Load obs data
self._obs = loader.cohort_loader.load_cohort_data(self.sources, self.obs_level)
# Every data load adds constant variables (not dependent on tmin tmax)
# Save for later referencing
self.constant_vars = [
col for col in self._obs.columns if col not in ObsLevel.primary_keys()
]
def load_default_vars(self) -> None:
"""
Load the default variables defined in ``vars.json``. It is recommended to use this after filtering your cohort for eligibility to speed up the process.
Returns:
None: Variables are loaded into the cohort.
Examples:
>>> # Load default variables for an ICU cohort
>>> cohort = Cohort(obs_level="icu_stay", load_default_vars=False)
>>> # Apply filters first (faster)
>>> cohort.include_list([
... {"variable": "age_on_admission", "operation": ">= 18", "label": "Adults"}
... ])
>>> # Then load default variables
>>> cohort.load_default_vars()
"""
vars = loader.var_loader.load_default_variables(self)
with logging_redirect_tqdm(loggers=[logger]):
for var in tqdm(vars, desc="Loading default variables", unit="var"):
try:
self.add_variable(var)
os.system("clear")
except Exception as e:
logger.error(f"Error {e}")
logger.error(f"Error traceback: {traceback.format_exc()}")
logger.info(
f"Could not load variable {var.var_name}. Continuing with next variable..."
)
continue
def add_variable(
self,
variable: str | Variable | MultiSourceVariable,
save_as: str | None = None,
tmin: TimeBoundColumn | None = None,
tmax: TimeBoundColumn | None = None,
) -> MultiSourceVariable:
"""Add a variable to the cohort.
You may specify tmin and tmax as a tuple (e.g. ("hospital_admission", "+1d")), in which case the time bound is the referenced column plus the given offset (here, one day after the patient's hospital admission).
Args:
variable: Variable to add. Either a string with the variable name (from `vars.json`) or a Variable object.
save_as: Name of the column to save the variable as. Defaults to variable name.
tmin: Name of the column to use as tmin or tuple (see description).
tmax: Name of the column to use as tmax or tuple (see description).
Returns:
Variable: The variable object.
Examples:
>>> cohort.add_variable("blood_sodium")
>>> cohort.add_variable(
... variable="anx_dx_covid_19",
... tmin=("hospital_admission", "-1d"),
... tmax=cohort.t_eligible
... )
>>> cohort.add_variable(
... NativeStatic(
... var_name="highest_hct_before_eligible",
... select="!max value",
... base_var='blood_hematokrit',
... tmax=cohort.t_eligible
... )
... )
>>> cohort.add_variable(
... variable='any_med_glu',
... save_as="glucose_prior_eligible",
... tmin=(cohort.t_eligible, "-48h"),
... tmax=cohort.t_eligible
... )
"""
# Backup
obs_backup = self._obs.clone()
obs_backup_idset = set(obs_backup.select(self.primary_key).unique().to_series())
# Load
if isinstance(variable, MultiSourceVariable):
var = variable
else:
var = self._load_variable(variable, tmin, tmax)
# Extract
data = var.extract(self, with_source_column=var.dynamic)
# Save
self._save_variable(
var_name=var.var_name,
var_dynamic=var.dynamic,
var_data=data,
save_as=save_as,
)
# Verify cohort integrity & restore if compromised
new_obs_idset = set(self._obs.select(self.primary_key).unique().to_series())
if new_obs_idset != obs_backup_idset:
# If mismatch - restore cohort
self._obs = obs_backup
raise CohortDataError(
"Adding variable altered cohort composition - data compromised."
)
# Assert no duplicates
self._validate_cohort()
return var
def _load_variable(
self,
variable: str | Variable,
tmin: TimeBoundColumn | None = None,
tmax: TimeBoundColumn | None = None,
) -> MultiSourceVariable:
"""Loads a MultiSourceVariable from JSON template or specified Variable object"""
if isinstance(variable, str):
tmin = tmin if tmin is not None else self.t_min
tmax = tmax if tmax is not None else self.t_max
return loader.var_loader.load_variable(
var_name=variable, cohort=self, tmin=tmin, tmax=tmax
)
else:
assert (
tmin is None and tmax is None
), "Please specify tmin and tmax directly in the Variable object"
source = utils.guess_variable_source(variable) or "Not applicable"
return MultiSourceVariable(
var_name=variable.var_name, variables={source: variable}
)
def _save_variable(
self,
var_name: str,
var_dynamic: bool,
var_data: pl.DataFrame | pd.DataFrame,
save_as=None,
) -> None:
"""Save a Variable object to the cohort.
Will either add a column to obs or add an entry to obsm (for dynamic variables).
Args:
var: Variable to save. (Must be already extracted)
save_as: Name of the column to save the variable as. Defaults to variable name.
Returns:
None: Variable is saved to the cohort.
"""
# NOTE: Won't allow "" as var name
save_as = save_as or var_name
data = utils.convert_to_polars_df(var_data)
if var_dynamic:
self._obsm[save_as] = data
else:
# Assert single observation per primary key
assert (
data.select(self.primary_key).is_unique().all()
), f"Variable data contains duplicates for {self.primary_key} ({save_as})."
# Find columns that start with save_as
drop_cols = self._obs.select(pl.selectors.starts_with(f"{save_as}")).columns
if drop_cols:
logger.info(
f"DROP existing variable(s): {', '.join(drop_cols)} (N={len(self._obs)})"
)
self._obs = self._obs.drop(drop_cols)
del drop_cols # Clear to avoid accidental re-use later
logger.debug(
f"Variable data size: {data.shape}, obs size: {self._obs.shape}"
)
# Rename columns based on conditions
logger.info("Renaming columns...")
if (
save_as != var_name
and len(data.columns) > 2 # more than one column + primary key
):
# Create a mapping for renaming columns
rename_map = {}
for col in data.columns:
if col not in [self.primary_key, save_as]:
parts = col.rsplit("_", 1)
if len(parts) > 1:
rename_map[col] = f"{save_as}_{parts[1]}"
if rename_map:
data = data.rename(rename_map)
else:
# Simple rename of the variable column
if var_name in data.columns:
data = data.rename({var_name: save_as})
# Assert that only expected columns are present
logger.debug("Simple rename check.")
assert (
data.select(self.primary_key, pl.selectors.starts_with(save_as)).columns
== data.columns
), "Variable data contains unexpected columns."
# Join with self.obs
logger.debug("Joining variable data with obs")
self._obs = self._obs.join(data, on=self.primary_key, how="left")
logger.info(f"SUCCESS: Saved variable: {save_as}")
gc.collect()
logger.info(f"Data summary:\n{data.describe()!r}")
if self.primary_key in data.columns:
nunique_info = str(data.select(pl.col(self.primary_key)).n_unique())
else:
nunique_info = "Not available"
logger.info(f"Number of unique non-NA entries: {nunique_info}")
# Time anchors, inclusion and exclusion
def set_t_eligible(self, t_eligible: str, drop_ineligible: bool = True) -> None:
"""
Set the time anchor for eligibility. This can be referenced as cohort.t_eligible throughout the process and is required to add inclusion or exclusion criteria.
Args:
t_eligible: Name of the column to use as t_eligible.
drop_ineligible: Whether to drop ineligible patients. Defaults to True.
Returns:
None: t_eligible is set.
Examples:
>>> # Add a suitable time-anchor variable
>>> cohort.add_variable(NativeStatic(
... var_name="spo2_lt_90",
... base_var="spo2",
... select="!first recordtime",
... where="value < 90",
... ))
>>> # Set the time anchor for eligibility
>>> cohort.set_t_eligible("spo2_lt_90")
"""
self._assert_datetime_col(t_eligible)
if self.t_eligible is not None:
warnings.warn(
f"WARNING: t_eligible already set to {self.t_eligible}. Will overwrite and set to {t_eligible}.\nCAVE! Previously ineligible patients will not be restored.",
UserWarning,
)
self.t_eligible = t_eligible
if drop_ineligible:
self._drop_ineligible()
def set_t_outcome(self, t_outcome: str) -> None:
"""
Set the time anchor for outcome. This can be referenced as cohort.t_outcome throughout the process and is recommended to specify for your study.
Args:
t_outcome (str): Name of the column to use as t_outcome.
Returns:
None: t_outcome is set.
Examples:
>>> cohort.set_t_outcome("hospital_discharge")
"""
self._assert_datetime_col(t_outcome)
if self.t_outcome is not None:
warnings.warn(
f"WARNING: t_outcome already set to {self.t_outcome}. Will overwrite and set to {t_outcome}.",
UserWarning,
)
self.t_outcome = t_outcome
def _assert_datetime_col(self, col: str) -> None:
assert col in self._obs.columns, f"Column {col} not found in obs."
assert (
self._obs[col].dtype == pl.Datetime
), f"Column {col} is not a datetime column."
def _drop_ineligible(self) -> None:
"""Drops ineligible observations (where t_eligible is NaT)."""
mask = self._obs[self.t_eligible].is_null()
# Log absolute and relative loss
n_dropped = mask.sum()
total = len(self._obs)
logger.info(
f"DROP: {n_dropped} rows ({n_dropped/total:.2%}) due to {self.t_eligible} is NaT"
)
# Drop ineligible observations and update state
self._obs = self._obs.remove(mask)
self._change_tracker.add_step(
description=f"Excluded {n_dropped} rows due to {self.t_eligible} is NaT",
after_df=self._obs,
)
def change_tracker(
self, description: str, mode: Literal["include", "exclude"] = "include"
):
"""Return a context manager to group cohort edits and record a single ChangeTracker state on exit.
Examples:
>>> with cohort.change_tracker("Adults", mode="include") as track:
...     track.filter(pl.col("age_on_admission") >= 18)
"""
return ChangeTrackerContext(
pipeline=self._change_tracker,
cohort=self,
description=description,
mode=mode,
)
def include(self, *args, **kwargs) -> None:
"""
Add an inclusion criterion to the cohort. It is recommended to use ``Cohort.include_list()`` and add all of your inclusion criteria at once. However, if you need to specify criteria at a later stage, you can use this method.
Warning:
You should call ``Cohort.include_list()`` before calling ``Cohort.include()`` to ensure that the inclusion criteria are properly tracked.
Args:
variable (str | Variable): Variable to use for the inclusion criterion.
operation (str): Operation to apply (e.g., ">= 18", "== True").
label (str): Short label for the inclusion step.
operations_done (str): Detailed description of what this criterion does.
tmin (str, optional): Start time for variable extraction.
tmax (str, optional): End time for variable extraction.
Returns:
None: Criterion is added to the cohort.
Note:
``operation`` is appended to a Polars column expression and evaluated (effectively ``pl.col(variable) <operation>``), so use standard Python comparison syntax such as ">= 18". Also, if you specify "true"/"True" or "false"/"False" as a value for ``operation``, it will be converted to "== True" or "== False", respectively.
Examples:
>>> cohort.include(
... variable="age_on_admission",
... operation=">= 18",
... label="Adult",
... operations_done="Include only adult patients"
... )
"""
return self._include_exclude("include", *args, **kwargs)
def exclude(self, *args, **kwargs) -> None:
"""
Add an exclusion criterion to the cohort. It is recommended to use ``Cohort.exclude_list()`` and add all of your exclusion criteria at once. However, if you need to specify criteria at a later stage, you can use this method.
Warning:
You should call ``Cohort.exclude_list()`` before calling ``Cohort.exclude()`` to ensure that the exclusion criteria are properly tracked.
Args:
variable (str | Variable): Variable to use for the exclusion criterion.
operation (str): Operation to apply (e.g., "> 5", "== True").
label (str): Short label for the exclusion step.
operations_done (str): Detailed description of what this criterion does.
tmin (str, optional): Start time for variable extraction.
tmax (str, optional): End time for variable extraction.
Returns:
None: Criterion is added to the cohort.
Note:
``operation`` is appended to a Polars column expression and evaluated (effectively ``pl.col(variable) <operation>``), so use standard Python comparison syntax such as "> 20". Also, if you specify "true"/"True" or "false"/"False" as a value for ``operation``, it will be converted to "== True" or "== False", respectively.
Examples:
>>> cohort.exclude(
... variable="elix_total",
... operation="> 20",
... operations_done="Exclude patients with high Elixhauser score"
... )
"""
return self._include_exclude("exclude", *args, **kwargs)
def _include_exclude(
self,
mode: Literal["include", "exclude"],
variable: str | Variable,
operation: str,
operations_done: str | None = None,
label: str | None = None,
allow_obs: bool = False,
tmin: str | None = None,
tmax: str | None = None,
) -> None:
"""
Add an inclusion or exclusion criterion to the cohort.
Args:
mode (Literal["include", "exclude"]): Whether to add an inclusion or exclusion criterion.
variable (str | Variable): Variable to use for the criterion.
operation (str): Operation to apply (e.g., ">= 18", "== True").
operations_done (str | None): Detailed description of what this criterion does.
label (str | None): Short label for the step.
allow_obs (bool): Allow using a stored variable in obs instead of re-extracting. CAVE: This can be dangerous when trying to set custom time bounds.
tmin (str, optional): Start time for variable extraction.
tmax (str, optional): End time for variable extraction.
Returns:
None: Criterion is added to the cohort.
"""
if len(self._obs) != self._change_tracker.current_snapshot.nuniques:
warnings.warn(
"WARNING: Current change tracker state is not equal to cohort size. This could indicate that the cohort was modified without proper tracking. Inclusion/Exclusion charts might be inaccurate.\n"
+ f"Cohort.obs: {len(self._obs)} observations / ChangeTracker: {self._change_tracker.current_snapshot.nuniques} observations",
UserWarning,
)
if operation.lower() in ["true", "false"]:
operation = "== True" if operation.lower() == "true" else "== False"
if tmax is None:
assert (
self.t_eligible is not None
), "t_eligible is not set. Please set t_eligible before adding inclusion criteria or set tmax manually (not recommended)."
tmax = self.t_eligible
if tmin is None:
tmin = "hospital_admission"
# Check if this is a defined variable or a custom variable
if variable in self.constant_vars or (
allow_obs and variable in self._obs.columns
):
data = self._obs.select([variable, self.primary_key])
var_name = variable
else:
var = self._load_variable(variable, tmin, tmax)
data = var.extract(self)
var_name = var.var_name
operations_done = operations_done or f"{var_name}{operation}"
try:
op_ids_expr = eval(f"pl.col('{var_name}'){operation}")
op_ids = data.filter(op_ids_expr)[self.primary_key]
except Exception as e:
logger.error(
f"Failed to add {mode} criterion ({variable}{operation}) with data:\n{data}\n{e}\n{traceback.format_exc()}"
)
# Re-raise: op_ids is undefined at this point, so continuing would fail anyway
raise
match mode:
case "include":
self._obs = self._obs.filter(pl.col(self.primary_key).is_in(op_ids))
self._change_tracker.add_step(
description=f"Included {operations_done}",
group=label,
after_df=self._obs,
)
case "exclude":
self._obs = self._obs.remove(pl.col(self.primary_key).is_in(op_ids))
self._change_tracker.add_step(
description=f"Excluded {operations_done}",
group=label,
after_df=self._obs,
)
def include_list(self, inclusion_list: list[dict] = []) -> ChangeTrackerPipeline:
"""
Add inclusion criteria to the cohort.
Args:
inclusion_list (list): List of inclusion criteria. Each criterion is a dictionary containing:
* ``variable`` (str | Variable): Variable to use for exclusion
* ``operation`` (str): Operation to apply (e.g., "> 5", "== True")
* ``label`` (str): Short label for the exclusion step
* ``operations_done`` (str): Detailed description of what this exclusion does
* ``tmin`` (str, optional): Start time for variable extraction
* ``tmax`` (str, optional): End time for variable extraction
Returns:
ct (ChangeTrackerPipeline): ChangeTrackerPipeline object; can be used to plot the inclusion flowchart.
Note:
By default, all inclusion criteria are applied from ``tmin="hospital_admission"`` to ``tmax=cohort.t_eligible``. This is recommended to avoid introducing immortal time bias. However, in some cases you might want to set custom time bounds.
Examples:
>>> ct = cohort.include_list([
... {
... "variable": "age_on_admission",
... "operation": ">= 18",
... "label": "Adult patients",
... "operations_done": "Excluded patients under 18 years old"
... }
... ])
>>> ct.create_flowchart()
"""
for inclusion in inclusion_list:
self.include(**inclusion)
return self._change_tracker
def exclude_list(self, exclusion_list: list[dict] = []) -> ChangeTrackerPipeline:
"""
Add exclusion criteria to the cohort.
Args:
exclusion_list (list): List of exclusion criteria. Each criterion is a dictionary containing:
* ``variable`` (str | Variable): Variable to use for exclusion
* ``operation`` (str): Operation to apply (e.g., "> 5", "== True")
* ``label`` (str): Short label for the exclusion step
* ``operations_done`` (str): Detailed description of what this exclusion does
* ``tmin`` (str, optional): Start time for variable extraction
* ``tmax`` (str, optional): End time for variable extraction
Returns:
ct (ChangeTrackerPipeline): ChangeTrackerPipeline object; can be used to plot the exclusion flowchart.
Note:
By default, all exclusion criteria are applied from ``tmin="hospital_admission"`` to ``tmax=cohort.t_eligible``. This is recommended to avoid introducing immortal time bias. However, in some cases you might want to set custom time bounds.
Examples:
>>> ct = cohort.exclude_list([
... {
... "variable": "any_rrt_icu",
... "operation": "true",
... "label": "No RRT",
... "operations_done": "Excluded RRT before hypernatremia"
... },
... {
... "variable": "any_dx_tbi",
... "operation": "true",
... "label": "No TBI",
... "operations_done": "Excluded TBI before hypernatremia"
... },
... {
... "variable": NativeStatic(
... var_name="sodium_count",
... select="!count value",
... base_var="blood_sodium"),
... "operation": "< 1",
... "label": "Final cohort",
... "operations_done": "Excluded cases with less than 1 sodium measurement after hypernatremia",
... "tmin": cohort.t_eligible,
... "tmax": "hospital_discharge"
... }
... ])
>>> ct.create_flowchart() # Plot the exclusion flowchart
"""
for exclusion in exclusion_list:
self.exclude(**exclusion)
return self._change_tracker
def add_variable_definition(self, var_name: str, var_dict: dict) -> None:
"""Add or update a local variable definition.
Args:
var_name (str): Name of the variable.
var_dict (dict): Dictionary containing variable definition. Can be partial -
missing fields will be inherited from global definition.
Examples:
Add a completely new variable:
>>> cohort.add_variable_definition("my_new_var", {
... "type": "native_dynamic",
... "table": "it_ishmed_labor",
... "where": "c_katalog_leistungtext LIKE '%new%'",
... "value_dtype": "DOUBLE",
... "cleaning": {"value": {"low": 100, "high": 150}}
... })
Partially override existing variable:
>>> cohort.add_variable_definition("blood_sodium", {
... "where": "c_katalog_leistungtext LIKE '%custom_sodium%'"
... })
"""
self.project_vars.setdefault(var_name, {})
self.project_vars[var_name].update(var_dict)
def add_inclusion(self) -> None:
raise NotImplementedError("Use include_list instead")
def add_exclusion(self) -> None:
raise NotImplementedError("Use exclude_list instead")
def _validate_cohort(self) -> None:
"""Validate data integrity"""
# Check for unique primary keys in .obs
if not self._obs.select(pl.col(self.primary_key)).is_unique().all():
raise CohortDataError(
f"Duplicate entries found in obs for primary key '{self.primary_key}'."
)
def to_csv(self, folder: str) -> None:
"""
Save the cohort to CSV files.
Args:
folder (str): Path to the folder where CSV files will be saved.
Examples:
>>> cohort.to_csv("output_data")
>>> # Creates:
>>> # output_data/_obs.csv
>>> # output_data/blood_sodium.csv
>>> # output_data/heart_rate.csv
>>> # ... (one file per variable)
"""
# NOTE: Convert to pandas to work around the missing support for nested data
logger.info(f"Saving cohort to {folder}")
os.makedirs(folder, exist_ok=True)
logger.info("Preparing for export...")
self._obs.to_pandas().to_csv(os.path.join(folder, "_obs.csv"), index=False)
logger.info("Saving variables...")
for var_name, var_data in self._obsm.items():
var_data.to_pandas().to_csv(
os.path.join(folder, f"{var_name}.csv"), index=False
)
logger.info("Done!")
def save(self, filename: str) -> None:
"""
Save the cohort to a single compressed .corr2 archive (.tar.zst equivalent). Saves cohort.__dict__ (excluding obs, obsm, variables, conn) to state.pkl, obs to obs.parquet, and each obsm DataFrame to obsm_<var_name>.parquet in a temp dir.
Args:
filename: Path to the .corr2 archive
Returns:
None
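Examples:
Minimal usage sketch (file path is illustrative):
>>> cohort.save("my_cohort")          # saved as my_cohort.corr2
>>> cohort.save("my_cohort.corr2")    # equivalent, with explicit extension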
"""
if filename.endswith(".corr"):
filename = filename + "2"
warnings.warn(
"The new file format is .corr2. Will use .corr2 instead of .corr.",
UserWarning,
)
elif "." not in filename:
filename = filename + ".corr2"
elif not filename.endswith(".corr2"):
raise ValueError(
"Unsupported file format. Please provide a filename ending with .corr2 or no file ending at all."
)
self._save_corr2(filename)
def _save_corr2(self, filename: str) -> None:
"""
Save the cohort to a single compressed .corr2 archive (.tar.zst equivalent).
Saves cohort.__dict__ (excluding obs, obsm, variables, conn) to state.pkl,
obs to obs.parquet, and each obsm DataFrame to obsm_<var_name>.parquet in a temp dir
Args:
filename: Path to the .corr2 archive
Returns:
None
"""
# NOTE: Default path is already set for the corr_vars module by the global tempfile.tempdir
with tempfile.TemporaryDirectory() as tmpdir:
# Save obs and obsm as parquet files
self._obs.write_parquet(os.path.join(tmpdir, "obs.parquet"))
for var_name, df in self._obsm.items():
df.write_parquet(os.path.join(tmpdir, f"obsm_{var_name}.parquet"))
# Prepare dict to save (exclude obs, obsm, variables, conn)
state = self.__dict__.copy()
state.pop("_obs", None)
state.pop("_obsm", None)
state.pop("obs", None)
state.pop("obsm", None)
state.pop("variables", None)
state.pop("conn", None)
# Pop password before saving to avoid unencrypted password in the archive
for source in state["sources"]:
if (
"conn_args" in state["sources"][source]
and "password" in state["sources"][source]["conn_args"]
):
state["sources"][source]["conn_args"].pop("password")
# Save the rest
state_path = os.path.join(tmpdir, "state.pkl")
with open(state_path, "wb") as f:
pickle.dump(state, f)
# Tar all files in tmpdir
tar_path = os.path.join(tmpdir, "archive.tar")
with tarfile.open(tar_path, "w") as tar:
for file in os.listdir(tmpdir):
if file != "archive.tar":
tar.add(os.path.join(tmpdir, file), arcname=file)
# Compress tar with zstandard
cctx = zstd.ZstdCompressor(level=3)
with open(tar_path, "rb") as tar_in, open(filename, "wb") as out:
out.write(cctx.compress(tar_in.read()))
logger.info(
f"SUCCESS: Saved cohort to {filename} ({len(self._obsm)} dynamic variables, {len(self._obs)} observations [{self.obs_level.lower_name}])"
)
@classmethod
def load(cls, filename: str) -> Cohort:
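"""Load a cohort from a ``.corr2`` archive previously written by ``Cohort.save()``.
Args:
filename: Path to the .corr2 archive.
Returns:
Cohort: The loaded cohort.
Examples:
>>> cohort = Cohort.load("my_cohort.corr2")  # illustrative path
"""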
if filename.endswith(".corr2"):
return cls._load_corr2(filename)
else:
raise ValueError(
"Unsupported file format. The new Cohort module only supports .corr2 files."
)
@classmethod
def _load_corr2(cls, filename: str, password: str | None = None) -> Cohort:
"""
Load a cohort from a single compressed .corr2 archive (.tar.zst equivalent).
Loads cohort.__dict__ from state.pkl, obs from obs.parquet, and obsm from obsm_<var_name>.parquet files
extracted from the archive.
Args:
filename: Path to the .corr2 archive
password: Password to restore for sources with password field
Returns:
Cohort: The loaded cohort
"""
# BACKWARDS COMPATIBILITY SECTION:
import io
CLASS_RENAMES = {
"corr_vars.utils.helpers": {
"ChangeTracker": "corr_vars.core.change_tracker.ChangeTrackerPipeline"
}
}
class _RenamingUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if module in CLASS_RENAMES and name in CLASS_RENAMES[module]:
new_path = CLASS_RENAMES[module][name]
new_mod, new_cls = new_path.rsplit(".", 1)
mod = __import__(new_mod, fromlist=[new_cls])
return getattr(mod, new_cls)
# Default
return super().find_class(module, name)
def load_with_rename(data: bytes):
return _RenamingUnpickler(io.BytesIO(data)).load()
# LOAD FILE
# NOTE: Default path is already set for the corr_vars module by the global tempfile.tempdir
with tempfile.TemporaryDirectory() as tmpdir:
tar_path = os.path.join(tmpdir, "archive.tar")
dctx = zstd.ZstdDecompressor()
with open(filename, "rb") as inp, open(tar_path, "wb") as out:
out.write(dctx.decompress(inp.read()))
with tarfile.open(tar_path, "r") as tar:
tar.extractall(tmpdir)
# Load state dict
state_path = os.path.join(tmpdir, "state.pkl")
with open(state_path, "rb") as f:
raw = f.read()
state = load_with_rename(raw)
# Create instance
cohort = cls.__new__(cls)
cohort.__dict__.update(state)
# Support older Cohort objects
if getattr(cohort, "sources", None) is not None:
cohort.sources = cohort._load_default_config_data(cohort.sources)
if getattr(cohort, "tmpdir_manager", None) is None:
cohort.tmpdir_manager = TemporaryDirectoryManager() # type: ignore[misc]
if getattr(cohort, "obs_level", None) and isinstance(cohort.obs_level, str):
cohort.obs_level = ObsLevel(cohort.obs_level)
old_tmpdir_path = getattr(cohort, "tmpdir_path", None)
# Ensure tmpdir exists and is valid
tmpdir_path = getattr(cohort, "_current_tmpdir_path", old_tmpdir_path)
cohort._current_tmpdir_path = cohort.tmpdir_manager.create_tmpdir(
tmpdir_path=tmpdir_path, check_existing=True
)
# Load parquet files
obs_path = os.path.join(tmpdir, "obs.parquet")
cohort._obs = pl.read_parquet(obs_path)
cohort._obsm = {}
for file in os.listdir(tmpdir):
if file.startswith("obsm_") and file.endswith(".parquet"):
var_name = file[5:-8] # remove 'obsm_' and '.parquet'
cohort._obsm[var_name] = pl.read_parquet(os.path.join(tmpdir, file))
# Restore password, which was removed to avoid unencrypted password in the archive
for source in state["sources"]:
if "conn_args" in state["sources"][source]:
state["sources"][source]["conn_args"]["password"] = password
# Add flag to indicate this cohort was loaded from a file
# Used in unit tests
cohort._from_file = True
logger.info(
f"SUCCESS: Loaded cohort from {filename} ({len(cohort._obsm)} dynamic variables, {len(cohort._obs)} observations [{cohort.obs_level.lower_name}])"
)
# Validate
cohort._validate_cohort()
return cohort
# Property methods
# We use _obs and _obsm to store underlying data so that the interface can be modified to pandas for the legacy version
@property
def obs(self) -> pl.DataFrame:
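"""Static observation-level data (one row per observation) as a polars DataFrame."""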
return self._obs
@obs.setter
def obs(self, value: pl.DataFrame | pl.LazyFrame | pd.DataFrame) -> None:
self._obs = utils.convert_to_polars_df(value)
@obs.deleter
def obs(self) -> None:
warnings.warn("Can't delete Cohort.obs!", UserWarning)
@property
def obsm(self) -> ObsmDict:
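"""Dynamic (time-series) data, wrapped in an ``ObsmDict`` for convenient access."""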
return ObsmDict(self._obsm)
@property
def tmpdir_path(self):
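"""Path of the temporary directory managed by ``tmpdir_manager`` (refreshed on each access)."""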
self._current_tmpdir_path = self.tmpdir_manager.path
return self._current_tmpdir_path
# System methods
def __repr__(self) -> str:
return f"Cohort(obs_level={self.obs_level.lower_name}, len(obs)={len(self)})"
def __str__(self) -> str:
return textwrap.dedent(
f"""\
Cohort object
obs_level: {self.obs_level.lower_name}
len(obs): {len(self)}
obs (first 5 rows):
{self._obs.head()}
"""
)
def __len__(self) -> int:
return len(self._obs)
def __iter__(self) -> Iterator[pl.Series]:
return iter(self._obs)
@overload
def __getitem__(self, key: tuple[SingleIndexSelector, SingleColSelector]) -> Any:
...
@overload
def __getitem__( # type: ignore[overload-overlap]
self, key: str | tuple[MultiIndexSelector, SingleColSelector]
) -> pl.Series:
...
@overload
def __getitem__(
self,
key: (
SingleIndexSelector
| MultiIndexSelector
| MultiColSelector
| tuple[SingleIndexSelector, MultiColSelector]
| tuple[MultiIndexSelector, MultiColSelector]
),
) -> pl.DataFrame:
...
def __getitem__(
self,
item: (
SingleIndexSelector
| SingleColSelector
| MultiColSelector
| MultiIndexSelector
| tuple[SingleIndexSelector, SingleColSelector]
| tuple[SingleIndexSelector, MultiColSelector]
| tuple[MultiIndexSelector, SingleColSelector]
| tuple[MultiIndexSelector, MultiColSelector]
),
) -> pl.DataFrame | pl.Series | Any:
return self._obs[item]
def __setitem__(
self, item: ArrayLike | str | None, value: ArrayLike | None
) -> None:
self._obs = self._obs.with_columns(pl.Series(item, value))
def __contains__(self, value: str) -> bool:
return value in self.obs
# Other print methods
def _repr_html_(self) -> str:
return textwrap.dedent(
f"""\
<article style="margin-bottom: 2em">
<h1 style="font-size: 1.5em; font-weight: 800; text-decoration: underline solid 2px; margin-block: .5em;">Cohort object</h1>
<ul style="font-family: monospace; font-size: 1em; list-style: none outside none; margin-block: .5em; padding: 0; display: flex; flex-direction: column; gap: .5em">
<li><b>obs_level:</b> {self.obs_level.lower_name}</li>
<li><b>len(obs):</b> {len(self)}</li>
</ul>
</article>
<div style="display: flex; flex-direction: column; gap: .5em">
<div style="font-family: monospace; font-size: 1em;"><b>obs (first 5 rows):</b></div>
<div style="overflow-x: auto; width: 100%">
{self._obs.head()._repr_html_()}
</div>
</div>
"""
)
def debug_print(self) -> None:
"""Print debug information about the cohort. Please use this if you are creating a GitHub issue.
Returns:
None
"""
utils.print_cohort_debug_info(self)
utils.print_debug_info()
# Utils methods
def to_stata(
self,
df: pd.DataFrame | None = None,
convert_dates: dict[Hashable, str] | None = None,
write_index: bool = True,
to_file: str | None = None,
) -> pd.DataFrame | None:
"""
Convert the cohort to a Stata DataFrame. You may use cohort.stata to access the dataframe directly.
Note that you need to save it to a top-level variable to access it via %%stata.
Args:
df (pd.DataFrame): The DataFrame to be converted to Stata format. Will default to
the obs DataFrame if unspecified (default: None)
convert_dates (dict[Hashable, str]): Dictionary of columns to convert to Stata date format.
write_index (bool): Whether to write the index as a column.
to_file (str | None): Path to save as .dta file. If left unspecified, the DataFrame will not be saved.
Returns:
pd.DataFrame: A Pandas Dataframe compatible with Stata if `to_file` is None.
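Examples:
Illustrative usage (the output path is an assumption):
>>> stata_df = cohort.to_stata()              # Stata-compatible DataFrame from obs
>>> cohort.to_stata(to_file="cohort.dta")     # write directly to a .dta file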
"""
# Avoid truthiness check on a DataFrame (ambiguous); fall back to obs only when df is None
stata_df = df if df is not None else self._obs.to_pandas()
return utils.convert_to_stata(
stata_df,
convert_dates=convert_dates,
write_index=write_index,
to_file=to_file,
)
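# Convenience accessor: ``cohort.stata`` is equivalent to ``cohort.to_stata()`` with default arguments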
stata = property(to_stata)
def to_tableone(
self,
ignore_cols: list | str = [],
groupby: str | None = None,
filter: str | None = None,
pval: bool = False,
normal_cols: list[str] = [],
**kwargs,
) -> TableOne:
"""
Create a `TableOne <https://tableone.readthedocs.io/en/latest/index.html>`_ object for the cohort.
Args:
ignore_cols (list | str): Column(s) to ignore.
groupby (str): Column to group by.
filter (str): Filter to apply to the data.
pval (bool): Whether to calculate p-values.
normal_cols (list[str]): Columns to treat as normally distributed.
**kwargs: Additional arguments to pass to TableOne.
Returns:
TableOne: A TableOne object.
Examples:
>>> tableone = cohort.to_tableone()
>>> print(tableone)
>>> tableone.to_csv("tableone.csv")
>>> tableone = cohort.to_tableone(groupby="sex", pval=False)
>>> print(tableone)
>>> tableone.to_csv("tableone_sex.csv")
"""
tab1_df = self._obs.to_pandas()
return utils.convert_to_tableone(
tab1_df,
ignore_cols=ignore_cols,
groupby=groupby,
filter=filter,
pval=pval,
normal_cols=normal_cols,
**kwargs,
)
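# Convenience accessor: ``cohort.tableone`` is equivalent to ``cohort.to_tableone()`` with default arguments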
tableone = property(to_tableone)