Source code for corr_vars.core.cohort

from __future__ import annotations

from datetime import datetime, timezone
import os
import pickle
import tarfile
import tempfile
import textwrap
import traceback
import warnings
import gc

import pandas as pd
import polars as pl
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
import zstandard as zstd

from corr_vars import logger
from corr_vars.core.change_tracker import ChangeTrackerPipeline, ChangeTrackerContext
from corr_vars.core.file_manager import TemporaryDirectoryManager
from corr_vars.core.variable import Variable
import corr_vars.sources as loader
from corr_vars.sources.var_loader import MultiSourceVariable
import corr_vars.utils as utils

from collections.abc import Iterator, MutableMapping, Hashable
from corr_vars.definitions import ObsLevel
from typing import Final, Literal, Any, overload, TYPE_CHECKING

if TYPE_CHECKING:
    from corr_vars.definitions import (
        ArrayLike,
        ObsLevelNames,
        SourceDictPartial,
        SourceDict,
        TimeBoundColumn,
    )
    from polars._typing import (
        SingleIndexSelector,
        SingleColSelector,
        MultiColSelector,
        MultiIndexSelector,
    )
    from tableone import TableOne


class CohortDataError(Exception):
    """Raised upon cohort data errors."""

    pass


class VariableNotFoundError(Exception):
    """Raised when a variable is not found in the cohort."""

    pass


class ObsmDict(MutableMapping[str, pl.DataFrame]):
    """Dictionary-like wrapper for obsm data providing convenient access and display.

    Args:
        data (dict): Dictionary of polars DataFrames containing dynamic variable data.
    """

    def __init__(self, data: dict[str, pl.DataFrame]) -> None:
        self._data = data

    def __getitem__(self, key) -> pl.DataFrame:
        try:
            return self._data[key]
        except KeyError:
            available_vars = list(self._data.keys())
            if available_vars:
                error_msg = (
                    f"Variable '{key}' not found in cohort.obsm. "
                    f"Available variables: {', '.join(available_vars)}"
                )
            else:
                error_msg = (
                    f"Variable '{key}' not found in cohort.obsm. "
                    f"No variables have been extracted yet. Use cohort.add_variable() to extract variables."
                )
            raise VariableNotFoundError(error_msg) from None

    def __contains__(self, key):
        return key in self._data

    def __setitem__(self, key, value) -> None:
        self._data[key] = value

    def __delitem__(self, key) -> None:
        del self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __repr__(self) -> str:
        lines = [f"{k}: shape={df.shape}" for k, df in self._data.items()]
        return "ObsmDict(\n  " + "\n  ".join(lines) + "\n)" if lines else "ObsmDict(Ø)"


class Cohort:
    """Class to build a cohort in the CORR database.

    Args:
        obs_level (Literal["icu_stay", "hospital_stay", "procedure"]): Observation level (default: "icu_stay").
        sources (dict[str, dict]): Dictionary of data sources to use for data extraction.
            Available options are "cub_hdp", "reprodicu", "co5". Source configurations:

            - cub_hdp: Accepted keys: database, password_file, merge_consecutive, filters,
              conn_args [remote_hostname (str), username (str)]
            - reprodicu: Accepted keys: path, exclude_datasets
            - co5: Not implemented yet. Data request is still pending.

            Note: reprodicu does not yet implement variable extraction, only cohort data.
        project_vars (dict): Dictionary with local variable definitions (default: {}).
        load_default_vars (bool): Whether to load the default variables (default: True).
        logger_args (dict): Dictionary of logging configurations [level (int), file_path (str),
            file_mode (str), verbose_fmt (bool), colored_output (bool), formatted_numbers (bool)]
            (default: {}).

    Attributes:
        obs (pl.DataFrame): Static data for each observation. Contains one row per observation
            (e.g., ICU stay) with columns for static variables like demographics and outcomes.

            Example:
                >>> cohort.obs
                   patient_id case_id icu_stay_id       icu_admission       icu_discharge sex ... inhospital_death
                0        P001    C001      C001_1 2023-01-01 08:30:00 2023-01-03 12:00:00   M ...            False
                1        P001    C001      C001_2 2023-01-03 14:20:00 2023-01-05 16:30:00   M ...            False
                2        P002    C002      C002_1 2023-01-02 09:15:00 2023-01-04 10:30:00   F ...            False
                3        P003    C003      C003_1 2023-01-04 11:45:00 2023-01-07 13:20:00   F ...             True
                ...

        obsm (dict of pl.DataFrame): Dynamic data stored as dictionary of DataFrames.
            Each DataFrame contains time-series data for a variable with columns:

            - recordtime: Timestamp of the measurement
            - value: Value of the measurement
            - recordtime_end: End time (only for duration-based variables like therapies)
            - description: Additional information (e.g., medication names)

            Example:
                >>> cohort.obsm["blood_sodium"]
                   icu_stay_id          recordtime  value
                0       C001_1 2023-01-01 09:30:00    138
                1       C001_1 2023-01-02 10:15:00    141
                2       C001_2 2023-01-03 15:00:00    137
                3       C002_1 2023-01-02 10:00:00    142
                4       C003_1 2023-01-04 12:30:00    139
                ...

    Notes:
        - For large cohorts, set ``load_default_vars=False`` to speed up the extraction.
          You can use pre-extracted cohorts as starting points and load them using ``Cohort.load()``.
        - Variables can be added using ``cohort.add_variable()``. Static variables will be added
          to ``obs``, dynamic variables to ``obsm``.
        - ``sources["cub_hdp"]["filters"]`` also allows a special shorthand "_dx" to extract the
          hospital admissions of the last x months, useful for debugging/prototyping.
          For example, use "_d2" to extract every admission of the last 2 months.

    Examples:
        Create a new cohort:

        >>> cohort = Cohort(
        ...     obs_level="icu_stay",
        ...     load_default_vars=False,
        ...     sources={
        ...         "cub_hdp": {
        ...             "database": "db_hypercapnia_prepared",
        ...             "password_file": True,
        ...             "merge_consecutive": False},
        ...         "reprodicu": {
        ...             "path": "/data02/projects/reprodicubility/reprodICU/reprodICU_files"}
        ...     })

        Access static data:

        >>> cohort.obs.select("age_on_admission")  # Get age for all patients
        >>> cohort.obs.filter(pl.col("sex") == "M")  # Filter for male patients

        Access time-series data:

        >>> cohort.obsm["blood_sodium"]  # Get all blood sodium measurements
        >>> # Get blood sodium measurements for a specific observation
        >>> cohort.obsm["blood_sodium"].filter(pl.col(cohort.primary_key) == "12345")
    """

    # Data variables
    _obs: pl.DataFrame
    _obsm: dict[str, pl.DataFrame]
    constant_vars: list[str]

    # Observation level variables
    obs_level: ObsLevel
    primary_key: str
    t_min: str
    t_max: str
    t_eligible: str
    t_outcome: str

    # Change tracker
    _change_tracker: Final[ChangeTrackerPipeline]

    # Configurations
    logger_args: dict[str, Any]
    sources: Final[SourceDict]
    project_vars: dict[str, dict[str, Any]]

    # Temporary directory variables
    tmpdir_manager: Final[TemporaryDirectoryManager]
    _current_tmpdir_path: str

    # Flags
    _from_file: bool
    _data_load_time: datetime

    # Properties (defined at the bottom of the Cohort class)
    # obs
    # tmpdir_path
    # stata
    # tableone

    def __init__(
        self,
        obs_level: ObsLevelNames = "icu_stay",
        sources: SourceDictPartial = {
            "cub_hdp": {
                "database": "db_hypercapnia_prepared",
                "icu_stay": {
                    "merge_consecutive": True,
                },
                "conn_args": {
                    "password_file": True,
                },
            }
        },
        project_vars: dict[str, dict[str, Any]] = {},
        load_default_vars: bool = True,
        logger_args: dict[str, Any] = {},
    ) -> None:
        # Set up logger FIRST
        self._setup_logger(logger_args)

        # Add flag to indicate this is a newly created cohort
        # Used in unit tests
        self._from_file = False

        self.sources = self._load_default_config_data(sources)
        self.project_vars = project_vars

        # Create temporary directory for storing intermediary data
        self.tmpdir_manager = TemporaryDirectoryManager()

        # Load data and set primary keys for variables
        self._set_obs_level(obs_level)
        self._load_obs_level_keys()  # Primary keys, tmin, tmax, etc.
        self._load_obs_level_data()  # Static / constant data
        self._obsm = {}  # Time series data

        # Load default variables
        if load_default_vars:
            self.load_default_vars()

        # Create Change Tracker object - Used for inclusion and exclusion
        self._change_tracker = ChangeTrackerPipeline(
            primary_key=self.primary_key,
            initial_df=self._obs,
            initial_description="Observations in database",
        )

        logger.info(
            f"SUCCESS: Extracted data. Cohort has {len(self)} observations ({self.obs_level.lower_name})."
        )

    def _setup_logger(self, logger_args: dict[str, Any]) -> None:
        """Changes logger configuration."""
        self.logger_args = logger_args
        utils.configure_logger_level_and_handlers(logger=logger, **self.logger_args)

    def _load_default_config_data(
        self,
        sources: SourceDictPartial,
    ) -> SourceDict:
        return loader.config_loader.load_default_config_data(sources=sources)

    def _set_obs_level(self, obs_level: str) -> None:
        """Set the observation level."""
        try:
            # Calling ObsLevel with a str is not an error
            # ObsLevel._missing_ will find the correct ObsLevel
            self.obs_level = ObsLevel(obs_level)  # type: ignore[call-arg]
        except ValueError:
            raise ValueError(f"Observation level {obs_level} not supported.")

    def _load_obs_level_keys(self) -> None:
        """Mirrors ObsLevel keys on Cohort object."""
        self.primary_key = self.obs_level.primary_key
        self.t_min = self.obs_level.t_min
        self.t_max = self.obs_level.t_max
        self.t_eligible = self.obs_level.t_eligible
        self.t_outcome = self.obs_level.t_outcome

    def _load_obs_level_data(self) -> None:
        """Load the static data for the observation level."""
        # Load timestamp for reproducibility
        self._data_load_time = datetime.now(timezone.utc)

        # Load obs data
        self._obs = loader.cohort_loader.load_cohort_data(self.sources, self.obs_level)

        # Every data load adds constant variables (not dependent on tmin/tmax)
        # Save for later referencing
        self.constant_vars = [
            col for col in self._obs.columns if col not in ObsLevel.primary_keys()
        ]

    def load_default_vars(self) -> None:
        """
        Load the default variables defined in ``vars.json``.

        It is recommended to use this after filtering your cohort for eligibility
        to speed up the process.

        Returns:
            None: Variables are loaded into the cohort.

        Examples:
            >>> # Load default variables for an ICU cohort
            >>> cohort = Cohort(obs_level="icu_stay", load_default_vars=False)
            >>> # Apply filters first (faster)
            >>> cohort.include_list([
            ...     {"variable": "age_on_admission", "operation": ">= 18", "label": "Adults"}
            ... ])
            >>> # Then load default variables
            >>> cohort.load_default_vars()
        """
        vars = loader.var_loader.load_default_variables(self)
        with logging_redirect_tqdm(loggers=[logger]):
            for var in tqdm(vars, desc="Loading default variables", unit="var"):
                try:
                    self.add_variable(var)
                    os.system("clear")
                except Exception as e:
                    logger.error(f"Error {e}")
                    logger.error(f"Error traceback: {traceback.format_exc()}")
                    logger.info(
                        f"Could not load variable {var.var_name}. Continuing with next variable..."
                    )
                    continue

    def add_variable(
        self,
        variable: str | Variable | MultiSourceVariable,
        save_as: str | None = None,
        tmin: TimeBoundColumn | None = None,
        tmax: TimeBoundColumn | None = None,
    ) -> MultiSourceVariable:
        """Add a variable to the cohort.

        You may specify tmin and tmax as a tuple (e.g. ("hospital_admission", "+1d")),
        in which case it will be relative to the hospital admission time of the patient.

        Args:
            variable: Variable to add. Either a string with the variable name (from `vars.json`)
                or a Variable object.
            save_as: Name of the column to save the variable as. Defaults to the variable name.
            tmin: Name of the column to use as tmin, or a tuple (see description).
            tmax: Name of the column to use as tmax, or a tuple (see description).

        Returns:
            Variable: The variable object.

        Examples:
            >>> cohort.add_variable("blood_sodium")
            >>> cohort.add_variable(
            ...     variable="anx_dx_covid_19",
            ...     tmin=("hospital_admission", "-1d"),
            ...     tmax=cohort.t_eligible
            ... )
            >>> cohort.add_variable(
            ...     NativeStatic(
            ...         var_name="highest_hct_before_eligible",
            ...         select="!max value",
            ...         base_var='blood_hematokrit',
            ...         tmax=cohort.t_eligible
            ...     )
            ... )
            >>> cohort.add_variable(
            ...     variable='any_med_glu',
            ...     save_as="glucose_prior_eligible",
            ...     tmin=(cohort.t_eligible, "-48h"),
            ...     tmax=cohort.t_eligible
            ... )
        """
        # Backup
        obs_backup = self._obs.clone()
        obs_backup_idset = set(obs_backup.select(self.primary_key).unique().to_series())

        # Load
        if isinstance(variable, MultiSourceVariable):
            var = variable
        else:
            var = self._load_variable(variable, tmin, tmax)

        # Extract
        data = var.extract(self, with_source_column=var.dynamic)

        # Save
        self._save_variable(
            var_name=var.var_name,
            var_dynamic=var.dynamic,
            var_data=data,
            save_as=save_as,
        )

        # Verify cohort integrity & restore if compromised
        new_obs_idset = set(self._obs.select(self.primary_key).unique().to_series())
        if not new_obs_idset == obs_backup_idset:
            # If mismatch - restore cohort
            self._obs = obs_backup
            raise CohortDataError(
                "Adding variable altered cohort composition - data compromised."
            )

        # Assert no duplicates
        self._validate_cohort()

        return var

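    # Usage sketch for `save_as` (illustrative; "sofa" and its columns are hypothetical
    # names, not variables shipped with the package): when a static variable extracts
    # several columns, _save_variable below keeps each column's last "_"-suffix, so
    # "sofa_score" / "sofa_time" land in obs as "sofa_d1_score" / "sofa_d1_time".
    #
    #     cohort.add_variable(
    #         "sofa",
    #         save_as="sofa_d1",
    #         tmin=("icu_admission", "+0h"),
    #         tmax=("icu_admission", "+24h"),
    #     )
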
    def _load_variable(
        self,
        variable: str | Variable,
        tmin: TimeBoundColumn | None = None,
        tmax: TimeBoundColumn | None = None,
    ) -> MultiSourceVariable:
        """Loads a MultiSourceVariable from a JSON template or a specified Variable object."""
        if isinstance(variable, str):
            tmin = tmin if tmin is not None else self.t_min
            tmax = tmax if tmax is not None else self.t_max
            return loader.var_loader.load_variable(
                var_name=variable, cohort=self, tmin=tmin, tmax=tmax
            )
        else:
            assert (
                tmin is None and tmax is None
            ), "Please specify tmin and tmax directly in the Variable object"
            source = utils.guess_variable_source(variable) or "Not applicable"
            return MultiSourceVariable(
                var_name=variable.var_name, variables={source: variable}
            )

    def _save_variable(
        self,
        var_name: str,
        var_dynamic: bool,
        var_data: pl.DataFrame | pd.DataFrame,
        save_as=None,
    ) -> None:
        """Save a Variable object to the cohort.

        Will either add a column to obs or add an entry to obsm (for dynamic variables).

        Args:
            var_name: Name of the (already extracted) variable to save.
            var_dynamic: Whether the variable is dynamic (time-series) or static.
            var_data: Extracted variable data.
            save_as: Name of the column to save the variable as. Defaults to the variable name.

        Returns:
            None: Variable is saved to the cohort.
        """
        # NOTE: Won't allow "" as var name
        save_as = save_as or var_name
        data = utils.convert_to_polars_df(var_data)

        if var_dynamic:
            self._obsm[save_as] = data
        else:
            # Assert single observation per primary key
            assert (
                data.select(self.primary_key).is_unique().all()
            ), f"Variable data contains duplicates for {self.primary_key} ({save_as})."

            # Find columns that start with save_as
            drop_cols = self._obs.select(pl.selectors.starts_with(f"{save_as}")).columns
            if drop_cols:
                logger.info(
                    f"DROP existing variable(s): {', '.join(drop_cols)} (N={len(self._obs)})"
                )
                self._obs = self._obs.drop(drop_cols)
            del drop_cols  # Clear to avoid accidental re-use later

            logger.debug(
                f"Variable data size: {data.shape}, obs size: {self._obs.shape}"
            )

            # Rename columns based on conditions
            logger.info("Renaming columns...")
            if (
                save_as != var_name
                and len(data.columns) > 2  # more than one column + primary key
            ):
                # Create a mapping for renaming columns
                rename_map = {}
                for col in data.columns:
                    if col not in [self.primary_key, save_as]:
                        parts = col.rsplit("_", 1)
                        if len(parts) > 1:
                            rename_map[col] = f"{save_as}_{parts[1]}"
                if rename_map:
                    data = data.rename(rename_map)
            else:
                # Simple rename of the variable column
                if var_name in data.columns:
                    data = data.rename({var_name: save_as})

                # Assert that only expected columns are present
                logger.debug("Simple rename check.")
                assert (
                    data.select(self.primary_key, pl.selectors.starts_with(save_as)).columns
                    == data.columns
                ), "Variable data contains unexpected columns."

            # Join with self.obs
            logger.debug("Joining variable data with obs")
            self._obs = self._obs.join(data, on=self.primary_key, how="left")

        logger.info(f"SUCCESS: Saved variable: {save_as}")
        gc.collect()

        logger.info(f"Data summary:\n{data.describe()!r}")
        if self.primary_key in data.columns:
            nunique_info = str(data.select(pl.col(self.primary_key)).n_unique())
        else:
            nunique_info = "Not available"
        logger.info(f"Number of unique non-NA entries: {nunique_info}")

    # Time anchors, inclusion and exclusion

    def set_t_eligible(self, t_eligible: str, drop_ineligible: bool = True) -> None:
        """
        Set the time anchor for eligibility.

        This can be referenced as cohort.t_eligible throughout the process and is
        required to add inclusion or exclusion criteria.

        Args:
            t_eligible: Name of the column to use as t_eligible.
            drop_ineligible: Whether to drop ineligible patients. Defaults to True.

        Returns:
            None: t_eligible is set.

        Examples:
            >>> # Add a suitable time-anchor variable
            >>> cohort.add_variable(NativeStatic(
            ...     var_name="spo2_lt_90",
            ...     base_var="spo2",
            ...     select="!first recordtime",
            ...     where="value < 90",
            ... ))
            >>> # Set the time anchor for eligibility
            >>> cohort.set_t_eligible("spo2_lt_90")
        """
        self._assert_datetime_col(t_eligible)
        if self.t_eligible is not None:
            warnings.warn(
                f"WARNING: t_eligible already set to {self.t_eligible}. Will overwrite and set to {t_eligible}.\nCAVE! Previously ineligible patients will not be restored.",
                UserWarning,
            )
        self.t_eligible = t_eligible
        if drop_ineligible:
            self._drop_ineligible()

    def set_t_outcome(self, t_outcome: str) -> None:
        """
        Set the time anchor for outcome.

        This can be referenced as cohort.t_outcome throughout the process and is
        recommended to specify for your study.

        Args:
            t_outcome (str): Name of the column to use as t_outcome.

        Returns:
            None: t_outcome is set.

        Examples:
            >>> cohort.set_t_outcome("hospital_discharge")
        """
        self._assert_datetime_col(t_outcome)
        if self.t_outcome is not None:
            warnings.warn(
                f"WARNING: t_outcome already set to {self.t_outcome}. Will overwrite and set to {t_outcome}.",
                UserWarning,
            )
        self.t_outcome = t_outcome

    def _assert_datetime_col(self, col: str) -> None:
        assert col in self._obs.columns, f"Column {col} not found in obs."
        assert (
            self._obs[col].dtype == pl.Datetime
        ), f"Column {col} is not a datetime column."

    def _drop_ineligible(self) -> None:
        """Drops ineligible observations (where t_eligible is NaT)."""
        mask = self._obs[self.t_eligible].is_null()

        # Log absolute and relative loss
        n_dropped = mask.sum()
        total = len(self._obs)
        logger.info(
            f"DROP: {n_dropped} rows ({n_dropped/total:.2%}) due to {self.t_eligible} is NaT"
        )

        # Drop ineligible observations and update state
        self._obs = self._obs.remove(mask)
        self._change_tracker.add_step(
            description=f"Excluded {n_dropped} rows due to {self.t_eligible} is NaT",
            after_df=self._obs,
        )

    def change_tracker(
        self, description: str, mode: Literal["include", "exclude"] = "include"
    ):
        """Return a context manager to group cohort edits and record a single
        ChangeTracker state on exit.

        Example:
            with cohort.change_tracker("Adults", mode="include") as track:
                track.filter(pl.col("age_on_admission") >= 18)
        """
        return ChangeTrackerContext(
            pipeline=self._change_tracker,
            cohort=self,
            description=description,
            mode=mode,
        )

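    # Why use the context manager (illustrative sketch; the column name is hypothetical):
    # edits applied directly to cohort.obs bypass the ChangeTracker and will trigger the
    # "modified without proper tracking" warning in _include_exclude below. Wrapping
    # them keeps the flowchart counts consistent and records one step on exit.
    #
    #     with cohort.change_tracker("Exclude short stays", mode="exclude") as track:
    #         track.filter(pl.col("icu_length_of_stay_hours") >= 24)
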
    def include(self, *args, **kwargs) -> None:
        """
        Add an inclusion criterion to the cohort.

        It is recommended to use ``Cohort.include_list()`` and add all of your inclusion
        criteria at once. However, if you need to specify criteria at a later stage,
        you can use this method.

        Warning:
            You should call ``Cohort.include_list()`` before calling ``Cohort.include()``
            to ensure that the inclusion criteria are properly tracked.

        Args:
            variable (str | Variable), operation (str), label (str), operations_done (str)
            [Optional: tmin, tmax]

        Returns:
            None: Criterion is added to the cohort.

        Note:
            ``operation`` is passed to ``pandas.DataFrame.query``, which uses a
            `slightly modified Python syntax <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html>`_.
            Also, if you specify "true"/"True" or "false"/"False" as a value for ``operation``,
            it will be converted to "== True" or "== False", respectively.

        Examples:
            >>> cohort.include(
            ...     variable="age_on_admission",
            ...     operation=">= 18",
            ...     label="Adult",
            ...     operations_done="Include only adult patients"
            ... )
        """
        return self._include_exclude("include", *args, **kwargs)

    def exclude(self, *args, **kwargs) -> None:
        """
        Add an exclusion criterion to the cohort.

        It is recommended to use ``Cohort.exclude_list()`` and add all of your exclusion
        criteria at once. However, if you need to specify criteria at a later stage,
        you can use this method.

        Warning:
            You should call ``Cohort.exclude_list()`` before calling ``Cohort.exclude()``
            to ensure that the exclusion criteria are properly tracked.

        Args:
            variable (str | Variable), operation (str), label (str), operations_done (str)
            [Optional: tmin, tmax]

        Returns:
            None: Criterion is added to the cohort.

        Note:
            ``operation`` is passed to ``pandas.DataFrame.query``, which uses a
            `slightly modified Python syntax <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html>`_.
            Also, if you specify "true"/"True" or "false"/"False" as a value for ``operation``,
            it will be converted to "== True" or "== False", respectively.

        Examples:
            >>> cohort.exclude(
            ...     variable="elix_total",
            ...     operation="> 20",
            ...     operations_done="Exclude patients with high Elixhauser score"
            ... )
        """
        return self._include_exclude("exclude", *args, **kwargs)

    def _include_exclude(
        self,
        mode: Literal["include", "exclude"],
        variable: str | Variable,
        operation: str,
        operations_done: str | None = None,
        label: str | None = None,
        allow_obs: bool = False,
        tmin: str | None = None,
        tmax: str | None = None,
    ) -> None:
        """
        Add an inclusion or exclusion criterion to the cohort.

        Args:
            mode (Literal["include", "exclude"]): Whether to add an inclusion or exclusion criterion.
            variable (str | Variable), operation (str), operations_done (str | None), label (str | None)
            allow_obs (bool): Allow using a stored variable in obs instead of re-extracting.
                CAVE: This can be dangerous when trying to set custom time bounds.
            [Optional: tmin, tmax]

        Returns:
            None: Criterion is added to the cohort.
        """
        if len(self._obs) != self._change_tracker.current_snapshot.nuniques:
            warnings.warn(
                "WARNING: Current change tracker state is not equal to cohort size. This could indicate that the cohort was modified without proper tracking. Inclusion/Exclusion charts might be inaccurate.\n"
                + f"Cohort.obs: {len(self._obs)} observations / ChangeTracker: {self._change_tracker.current_snapshot.nuniques} observations",
                UserWarning,
            )

        if operation.lower() in ["true", "false"]:
            operation = "== True" if operation.lower() == "true" else "== False"

        if tmax is None:
            assert (
                self.t_eligible is not None
            ), "t_eligible is not set. Please set t_eligible before adding inclusion criteria or set tmax manually (not recommended)."
            tmax = self.t_eligible
        if tmin is None:
            tmin = "hospital_admission"

        # Check if this is a defined variable or a custom variable
        if variable in self.constant_vars or (
            allow_obs and variable in self._obs.columns
        ):
            data = self._obs.select([variable, self.primary_key])
            var_name = variable
        else:
            var = self._load_variable(variable, tmin, tmax)
            data = var.extract(self)
            var_name = var.var_name

        operations_done = operations_done or f"{var_name}{operation}"

        try:
            op_ids_expr = eval(f"pl.col('{var_name}'){operation}")
            op_ids = data.filter(op_ids_expr)[self.primary_key]
        except Exception as e:
            logger.error(
                f"Failed to add {mode} criterion ({variable}{operation}) with data:\n{data}\n{e}\n{traceback.format_exc()}"
            )
            # Re-raise: op_ids would otherwise be undefined below
            raise

        match mode:
            case "include":
                self._obs = self._obs.filter(pl.col(self.primary_key).is_in(op_ids))
                self._change_tracker.add_step(
                    description=f"Included {operations_done}",
                    group=label,
                    after_df=self._obs,
                )
            case "exclude":
                self._obs = self._obs.remove(pl.col(self.primary_key).is_in(op_ids))
                self._change_tracker.add_step(
                    description=f"Excluded {operations_done}",
                    group=label,
                    after_df=self._obs,
                )

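    # Sketch of how `operation` strings are resolved above (values are illustrative):
    # the string is appended to a polars column expression and evaluated, so
    # operation=">= 18" on "age_on_admission" becomes `pl.col('age_on_admission') >= 18`,
    # and the shorthands "true"/"false" are first rewritten to "== True"/"== False".
    #
    #     cohort.include(variable="age_on_admission", operation=">= 18", label="Adults")
    #     cohort.exclude(variable="any_rrt_icu", operation="true", label="No RRT")
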
    def include_list(self, inclusion_list: list[dict] = []) -> ChangeTrackerPipeline:
        """
        Add inclusion criteria to the cohort.

        Args:
            inclusion_list (list): List of inclusion criteria. Each criterion is a dictionary with keys:

                * ``variable`` (str | Variable): Variable to use for inclusion
                * ``operation`` (str): Operation to apply (e.g., "> 5", "== True")
                * ``label`` (str): Short label for the inclusion step
                * ``operations_done`` (str): Detailed description of what this inclusion does
                * ``tmin`` (str, optional): Start time for variable extraction
                * ``tmax`` (str, optional): End time for variable extraction

        Returns:
            ct (ChangeTrackerPipeline): Change tracker object, can be used to plot the inclusion chart.

        Note:
            By default, all inclusion criteria are applied from ``tmin=cohort.t_min`` to
            ``tmax=cohort.t_eligible``. This is recommended to avoid introducing immortal
            time bias. However, in some cases you might want to set custom time bounds.

        Examples:
            >>> ct = cohort.include_list([
            ...     {
            ...         "variable": "age_on_admission",
            ...         "operation": ">= 18",
            ...         "label": "Adult patients",
            ...         "operations_done": "Excluded patients under 18 years old"
            ...     }
            ... ])
            >>> ct.create_flowchart()
        """
        for inclusion in inclusion_list:
            self.include(**inclusion)
        return self._change_tracker

    def exclude_list(self, exclusion_list: list[dict] = []) -> ChangeTrackerPipeline:
        """
        Add exclusion criteria to the cohort.

        Args:
            exclusion_list (list): List of exclusion criteria. Each criterion is a dictionary containing:

                * ``variable`` (str | Variable): Variable to use for exclusion
                * ``operation`` (str): Operation to apply (e.g., "> 5", "== True")
                * ``label`` (str): Short label for the exclusion step
                * ``operations_done`` (str): Detailed description of what this exclusion does
                * ``tmin`` (str, optional): Start time for variable extraction
                * ``tmax`` (str, optional): End time for variable extraction

        Returns:
            ct (ChangeTrackerPipeline): Change tracker object, can be used to plot the exclusion chart.

        Note:
            By default, all exclusion criteria are applied from ``tmin=cohort.t_min`` to
            ``tmax=cohort.t_eligible``. This is recommended to avoid introducing immortal
            time bias. However, in some cases you might want to set custom time bounds.

        Examples:
            >>> ct = cohort.exclude_list([
            ...     {
            ...         "variable": "any_rrt_icu",
            ...         "operation": "true",
            ...         "label": "No RRT",
            ...         "operations_done": "Excluded RRT before hypernatremia"
            ...     },
            ...     {
            ...         "variable": "any_dx_tbi",
            ...         "operation": "true",
            ...         "label": "No TBI",
            ...         "operations_done": "Excluded TBI before hypernatremia"
            ...     },
            ...     {
            ...         "variable": NativeStatic(
            ...             var_name="sodium_count",
            ...             select="!count value",
            ...             base_var="blood_sodium"),
            ...         "operation": "< 1",
            ...         "label": "Final cohort",
            ...         "operations_done": "Excluded cases with less than 1 sodium measurement after hypernatremia",
            ...         "tmin": cohort.t_eligible,
            ...         "tmax": "hospital_discharge"
            ...     }
            ... ])
            >>> ct.create_flowchart()  # Plot the exclusion flowchart
        """
        for exclusion in exclusion_list:
            self.exclude(**exclusion)
        return self._change_tracker

    def add_variable_definition(self, var_name: str, var_dict: dict) -> None:
        """Add or update a local variable definition.

        Args:
            var_name (str): Name of the variable.
            var_dict (dict): Dictionary containing the variable definition. Can be partial -
                missing fields will be inherited from the global definition.

        Examples:
            Add a completely new variable:

            >>> cohort.add_variable_definition("my_new_var", {
            ...     "type": "native_dynamic",
            ...     "table": "it_ishmed_labor",
            ...     "where": "c_katalog_leistungtext LIKE '%new%'",
            ...     "value_dtype": "DOUBLE",
            ...     "cleaning": {"value": {"low": 100, "high": 150}}
            ... })

            Partially override an existing variable:

            >>> cohort.add_variable_definition("blood_sodium", {
            ...     "where": "c_katalog_leistungtext LIKE '%custom_sodium%'"
            ... })
        """
        self.project_vars.setdefault(var_name, {})
        self.project_vars[var_name].update(var_dict)

    def add_inclusion(self) -> None:
        raise NotImplementedError("Use include_list instead")

    def add_exclusion(self) -> None:
        raise NotImplementedError("Use exclude_list instead")

    def _validate_cohort(self) -> None:
        """Validate data integrity."""
        # Check for unique primary keys in .obs
        if not self._obs.select(pl.col(self.primary_key)).is_unique().all():
            raise CohortDataError(
                f"Duplicate entries found in obs for primary key '{self.primary_key}'."
            )

    def to_csv(self, folder: str) -> None:
        """
        Save the cohort to CSV files.

        Args:
            folder (str): Path to the folder where CSV files will be saved.

        Examples:
            >>> cohort.to_csv("output_data")
            >>> # Creates:
            >>> # output_data/_obs.csv
            >>> # output_data/blood_sodium.csv
            >>> # output_data/heart_rate.csv
            >>> # ... (one file per variable)
        """
        # NOTE: Convert to pandas to work around the missing support for nested data
        logger.info(f"Saving cohort to {folder}")
        os.makedirs(folder, exist_ok=True)

        logger.info("Preparing for export...")
        self._obs.to_pandas().to_csv(os.path.join(folder, "_obs.csv"), index=False)

        logger.info("Saving variables...")
        for var_name, var_data in self._obsm.items():
            var_data.to_pandas().to_csv(
                os.path.join(folder, f"{var_name}.csv"), index=False
            )
        logger.info("Done!")

    def save(self, filename: str) -> None:
        """
        Save the cohort to a single compressed .corr2 archive (.tar.zst equivalent).

        Saves cohort.__dict__ (excluding obs, obsm, variables, conn) to state.pkl,
        obs to obs.parquet, and each obsm DataFrame to obsm_<var_name>.parquet in a temp dir.

        Args:
            filename: Path to the .corr2 archive

        Returns:
            None
        """
        if filename.endswith(".corr"):
            filename = filename + "2"
            warnings.warn(
                "The new file format is .corr2. Will use .corr2 instead of .corr.",
                UserWarning,
            )
        elif "." not in filename:
            filename = filename + ".corr2"
        elif not filename.endswith(".corr2"):
            raise ValueError(
                "Unsupported file format. Please provide a filename ending with .corr2 or no file ending at all."
            )
        self._save_corr2(filename)

    def _save_corr2(self, filename: str) -> None:
        """
        Save the cohort to a single compressed .corr2 archive (.tar.zst equivalent).

        Saves cohort.__dict__ (excluding obs, obsm, variables, conn) to state.pkl,
        obs to obs.parquet, and each obsm DataFrame to obsm_<var_name>.parquet in a temp dir.

        Args:
            filename: Path to the .corr2 archive

        Returns:
            None
        """
        # NOTE: Default path is already set for the corr_vars module by the global tempfile.tempdir
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save obs and obsm as parquet files
            self._obs.write_parquet(os.path.join(tmpdir, "obs.parquet"))
            for var_name, df in self._obsm.items():
                df.write_parquet(os.path.join(tmpdir, f"obsm_{var_name}.parquet"))

            # Prepare dict to save (exclude obs, obsm, variables, conn)
            state = self.__dict__.copy()
            state.pop("_obs", None)
            state.pop("_obsm", None)
            state.pop("obs", None)
            state.pop("obsm", None)
            state.pop("variables", None)
            state.pop("conn", None)

            # Pop password before saving to avoid unencrypted password in the archive
            for source in state["sources"]:
                if (
                    "conn_args" in state["sources"][source]
                    and "password" in state["sources"][source]["conn_args"]
                ):
                    state["sources"][source]["conn_args"].pop("password")

            # Save the rest
            state_path = os.path.join(tmpdir, "state.pkl")
            with open(state_path, "wb") as f:
                pickle.dump(state, f)

            # Tar all files in tmpdir
            tar_path = os.path.join(tmpdir, "archive.tar")
            with tarfile.open(tar_path, "w") as tar:
                for file in os.listdir(tmpdir):
                    if file != "archive.tar":
                        tar.add(os.path.join(tmpdir, file), arcname=file)

            # Compress tar with zstandard
            cctx = zstd.ZstdCompressor(level=3)
            with open(tar_path, "rb") as tar_in, open(filename, "wb") as out:
                out.write(cctx.compress(tar_in.read()))

        logger.info(
            f"SUCCESS: Saved cohort to {filename} ({len(self._obsm)} dynamic variables, {len(self._obs)} observations [{self.obs_level.lower_name}])"
        )

    @classmethod
    def load(cls, filename: str) -> Cohort:
        """Load a cohort from a .corr2 archive created by ``Cohort.save()``."""
        if filename.endswith(".corr2"):
            return cls._load_corr2(filename)
        else:
            raise ValueError(
                "Unsupported file format. The new Cohort module only supports .corr2 files."
            )

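    # Round-trip sketch (the file name is illustrative): save() writes a zstd-compressed
    # tar containing state.pkl, obs.parquet and one obsm_<name>.parquet per dynamic
    # variable; load() restores the same cohort. Source passwords are stripped before
    # saving and are not restored automatically.
    #
    #     cohort.save("my_cohort")                     # -> my_cohort.corr2
    #     restored = Cohort.load("my_cohort.corr2")
    #     assert len(restored) == len(cohort)
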
    @classmethod
    def _load_corr2(cls, filename: str, password: str | None = None) -> Cohort:
        """
        Load a cohort from a single compressed .corr2 archive (.tar.zst equivalent).

        Loads cohort.__dict__ from state.pkl, obs from obs.parquet, and obsm from
        obsm_<var_name>.parquet files extracted from the archive.

        Args:
            filename: Path to the .corr2 archive
            password: Password to restore for sources with a password field

        Returns:
            Cohort: The loaded cohort
        """
        # BACKWARDS COMPATIBILITY SECTION:
        import io

        CLASS_RENAMES = {
            "corr_vars.utils.helpers": {
                "ChangeTracker": "corr_vars.core.change_tracker.ChangeTrackerPipeline"
            }
        }

        class _RenamingUnpickler(pickle.Unpickler):
            def find_class(self, module, name):
                if module in CLASS_RENAMES and name in CLASS_RENAMES[module]:
                    new_path = CLASS_RENAMES[module][name]
                    new_mod, new_cls = new_path.rsplit(".", 1)
                    mod = __import__(new_mod, fromlist=[new_cls])
                    return getattr(mod, new_cls)
                # Default
                return super().find_class(module, name)

        def load_with_rename(data: bytes):
            return _RenamingUnpickler(io.BytesIO(data)).load()

        # LOAD FILE
        # NOTE: Default path is already set for the corr_vars module by the global tempfile.tempdir
        with tempfile.TemporaryDirectory() as tmpdir:
            tar_path = os.path.join(tmpdir, "archive.tar")
            dctx = zstd.ZstdDecompressor()
            with open(filename, "rb") as inp, open(tar_path, "wb") as out:
                out.write(dctx.decompress(inp.read()))
            with tarfile.open(tar_path, "r") as tar:
                tar.extractall(tmpdir)

            # Load state dict
            state_path = os.path.join(tmpdir, "state.pkl")
            with open(state_path, "rb") as f:
                raw = f.read()
            state = load_with_rename(raw)

            # Create instance
            cohort = cls.__new__(cls)
            cohort.__dict__.update(state)

            # Support older Cohort objects
            if getattr(cohort, "sources", None) is not None:
                cohort.sources = cohort._load_default_config_data(cohort.sources)
            if getattr(cohort, "tmpdir_manager", None) is None:
                cohort.tmpdir_manager = TemporaryDirectoryManager()  # type: ignore[misc]
            if getattr(cohort, "obs_level", None) and isinstance(cohort.obs_level, str):
                cohort.obs_level = ObsLevel(cohort.obs_level)
            old_tmpdir_path = getattr(cohort, "tmpdir_path", None)

            # Ensure tmpdir exists and is valid
            tmpdir_path = getattr(cohort, "_current_tmpdir_path", old_tmpdir_path)
            cohort._current_tmpdir_path = cohort.tmpdir_manager.create_tmpdir(
                tmpdir_path=tmpdir_path, check_existing=True
            )

            # Load parquet files
            obs_path = os.path.join(tmpdir, "obs.parquet")
            cohort._obs = pl.read_parquet(obs_path)
            cohort._obsm = {}
            for file in os.listdir(tmpdir):
                if file.startswith("obsm_") and file.endswith(".parquet"):
                    var_name = file[5:-8]  # remove 'obsm_' and '.parquet'
                    cohort._obsm[var_name] = pl.read_parquet(os.path.join(tmpdir, file))

        # Restore password, which was removed to avoid an unencrypted password in the archive
        for source in state["sources"]:
            if "conn_args" in state["sources"][source]:
                state["sources"][source]["conn_args"]["password"] = password

        # Add flag to indicate this is an old cohort
        # Used in unit tests
        cohort._from_file = True

        logger.info(
            f"SUCCESS: Loaded cohort from {filename} ({len(cohort._obsm)} dynamic variables, {len(cohort._obs)} observations [{cohort.obs_level.lower_name}])"
        )

        # Validate
        cohort._validate_cohort()

        return cohort

    # Property methods
    # We use _obs and _obsm to store the underlying data so that the interface can be
    # modified to pandas for the legacy version

    @property
    def obs(self) -> pl.DataFrame:
        return self._obs

    @obs.setter
    def obs(self, value: pl.DataFrame | pl.LazyFrame | pd.DataFrame) -> None:
        self._obs = utils.convert_to_polars_df(value)

    @obs.deleter
    def obs(self) -> None:
        warnings.warn("Can't delete Cohort.obs!", UserWarning)

    @property
    def obsm(self) -> ObsmDict:
        return ObsmDict(self._obsm)

    @property
    def tmpdir_path(self):
        self._current_tmpdir_path = self.tmpdir_manager.path
        return self._current_tmpdir_path

    # System methods

    def __repr__(self) -> str:
        return f"Cohort(obs_level={self.obs_level.lower_name}, len(obs)={len(self)})"

    def __str__(self) -> str:
        return textwrap.dedent(
            f"""\
            Cohort object
            obs_level: {self.obs_level.lower_name}
            len(obs): {len(self)}
            obs (first 5 rows):
            {self._obs.head()}
            """
        )

    def __len__(self) -> int:
        return len(self._obs)

    def __iter__(self) -> Iterator[pl.Series]:
        return iter(self._obs)

    @overload
    def __getitem__(self, key: tuple[SingleIndexSelector, SingleColSelector]) -> Any: ...

    @overload
    def __getitem__(  # type: ignore[overload-overlap]
        self, key: str | tuple[MultiIndexSelector, SingleColSelector]
    ) -> pl.Series: ...

    @overload
    def __getitem__(
        self,
        key: (
            SingleIndexSelector
            | MultiIndexSelector
            | MultiColSelector
            | tuple[SingleIndexSelector, MultiColSelector]
            | tuple[MultiIndexSelector, MultiColSelector]
        ),
    ) -> pl.DataFrame: ...

    def __getitem__(
        self,
        item: (
            SingleIndexSelector
            | SingleColSelector
            | MultiColSelector
            | MultiIndexSelector
            | tuple[SingleIndexSelector, SingleColSelector]
            | tuple[SingleIndexSelector, MultiColSelector]
            | tuple[MultiIndexSelector, SingleColSelector]
            | tuple[MultiIndexSelector, MultiColSelector]
        ),
    ) -> pl.DataFrame | pl.Series | Any:
        return self._obs[item]

    def __setitem__(
        self, item: ArrayLike | str | None, value: ArrayLike | None
    ) -> None:
        self._obs = self._obs.with_columns(pl.Series(item, value))

    def __contains__(self, value: str) -> bool:
        return value in self.obs

    # Other print methods

    def _repr_html_(self) -> str:
        return textwrap.dedent(
            f"""\
            <article style="margin-bottom: 2em">
                <h1 style="font-size: 1.5em; font-weight: 800; text-decoration: underline solid 2px; margin-block: .5em;">Cohort object</h1>
                <ul style="font-family: monospace; font-size: 1em; list-style: none outside none; margin-block: .5em; padding: 0; display: flex; flex-direction: column; gap: .5em">
                    <li><b>obs_level:</b> {self.obs_level.lower_name}</li>
                    <li><b>len(obs):</b> {len(self)}</li>
                </ul>
            </article>
            <div style="display: flex; flex-direction: column; gap: .5em">
                <div style="font-family: monospace; font-size: 1em;"><b>obs (first 5 rows):</b></div>
                <div style="overflow-x: auto; width: 100%">
                    {self._obs.head()._repr_html_()}
                </div>
            </div>
            """
        )

    def debug_print(self) -> None:
        """Print debug information about the cohort.

        Please use this if you are creating a GitHub issue.

        Returns:
            None
        """
        utils.print_cohort_debug_info(self)
        utils.print_debug_info()

    # Utils methods

    def to_stata(
        self,
        df: pd.DataFrame | None = None,
        convert_dates: dict[Hashable, str] | None = None,
        write_index: bool = True,
        to_file: str | None = None,
    ) -> pd.DataFrame | None:
        """
        Convert the cohort to a Stata DataFrame.

        You may use cohort.stata to access the dataframe directly. Note that you need to
        save it to a top-level variable to access it via %%stata.

        Args:
            df (pd.DataFrame): The DataFrame to be converted to Stata format.
                Will default to the obs DataFrame if unspecified (default: None).
            convert_dates (dict[Hashable, str]): Dictionary of columns to convert to Stata date format.
            write_index (bool): Whether to write the index as a column.
            to_file (str | None): Path to save as a .dta file. If left unspecified, the DataFrame will not be saved.

        Returns:
            pd.DataFrame: A pandas DataFrame compatible with Stata if `to_file` is None.
        """
        # Explicit None check: truth-testing a DataFrame with `or` raises ValueError
        stata_df = df if df is not None else self._obs.to_pandas()
        return utils.convert_to_stata(
            stata_df,
            convert_dates=convert_dates,
            write_index=write_index,
            to_file=to_file,
        )

    stata = property(to_stata)

    def to_tableone(
        self,
        ignore_cols: list | str = [],
        groupby: str | None = None,
        filter: str | None = None,
        pval: bool = False,
        normal_cols: list[str] = [],
        **kwargs,
    ) -> TableOne:
        """
        Create a `TableOne <https://tableone.readthedocs.io/en/latest/index.html>`_ object for the cohort.

        Args:
            ignore_cols (list | str): Column(s) to ignore.
            groupby (str): Column to group by.
            filter (str): Filter to apply to the data.
            pval (bool): Whether to calculate p-values.
            normal_cols (list[str]): Columns to treat as normally distributed.
            **kwargs: Additional arguments to pass to TableOne.

        Returns:
            TableOne: A TableOne object.

        Examples:
            >>> tableone = cohort.to_tableone()
            >>> print(tableone)
            >>> tableone.to_csv("tableone.csv")

            >>> tableone = cohort.to_tableone(groupby="sex", pval=False)
            >>> print(tableone)
            >>> tableone.to_csv("tableone_sex.csv")
        """
        tab1_df = self._obs.to_pandas()
        return utils.convert_to_tableone(
            tab1_df,
            ignore_cols=ignore_cols,
            groupby=groupby,
            filter=filter,
            pval=pval,
            normal_cols=normal_cols,
            **kwargs,
        )

    tableone = property(to_tableone)
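

# Minimal end-to-end sketch, assuming access to a configured "cub_hdp" source.
# The database name, variable names and output file name are illustrative values
# taken from the docstrings above, not a prescribed workflow.
if __name__ == "__main__":
    cohort = Cohort(
        obs_level="icu_stay",
        load_default_vars=False,
        sources={"cub_hdp": {"database": "db_hypercapnia_prepared", "password_file": True}},
    )
    cohort.include_list(
        [{"variable": "age_on_admission", "operation": ">= 18", "label": "Adults"}]
    )
    cohort.add_variable("blood_sodium")  # dynamic -> cohort.obsm["blood_sodium"]
    cohort.save("example_cohort")  # -> example_cohort.corr2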