Source code for corr_vars.legacy_v1.cohort

# Standard library imports
import gzip
import pickle

from typing import Literal

import ehrapy as ep
import pandas as pd
import polars as pl

from corr_vars import logger
from corr_vars.core.cohort import Cohort as NewCohort


# Put the old cohort tracker back in for compatibility with older .corr files
class CohortTracker(ep.tl.CohortTracker):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def plot_cohort_barplot(self, *args, **kwargs):
        """Disabled: column tracking is not implemented due to unhashable column types."""
        raise NotImplementedError(
            "Column tracking is not implemented due to unhashable column types."
        )


class LegacyObsDataFrame(pd.DataFrame):
    """Provides functionality to access the obs dataframe as a pandas dataframe.
    Also includes a method for direct column assignment.

    Args:
        *args: Positional arguments passed to pd.DataFrame.
        cohort (Cohort, optional): Reference to the parent Cohort instance (default: None).
        **kwargs: Keyword arguments passed to pd.DataFrame.
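
    Example:
        A minimal write-through sketch; the column name ``age_over_65`` is hypothetical:

        >>> obs = cohort.obs  # pandas view backed by cohort._obs
        >>> obs["age_over_65"] = obs["age_on_admission"] > 65  # also updates cohort._obs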
    """

    def __init__(self, *args, cohort=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._cohort = cohort

    def __setitem__(self, key, value):
        if self._cohort is None:
            raise ValueError("DataFrame is not connected to a Cohort instance")
        if hasattr(value, "__len__") and not isinstance(value, str):
            # Sized values must match the row count and are written through as a column
            if len(value) != self.shape[0]:
                raise ValueError("Length of value does not match the number of rows.")
            self._cohort._obs = self._cohort._obs.with_columns(pl.Series(key, value))
        else:
            # Scalars (including strings) are broadcast to every row
            self._cohort._obs = self._cohort._obs.with_columns(pl.lit(value).alias(key))
        super().__setitem__(key, value)


class LegacyObsmDataframe:
    """Provides functionality to access the obsm dataframe as a pandas dataframe.
    Using polars-native unnest method to increase performance over pd.json_normalize.

    Args:
        polars_df (pl.DataFrame): The polars DataFrame to wrap.
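
    Example:
        A minimal sketch; the struct column ``measurement`` is hypothetical:

        >>> df = pl.DataFrame({"measurement": [{"value": 138, "unit": "mmol/l"}]})
        >>> wrapped = LegacyObsmDataframe(df)
        >>> wrapped.unnest("measurement")["value"]  # polars unnest, pandas-style access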
    """

    def __init__(self, polars_df: pl.DataFrame):
        self.polars_df = polars_df
        self._sync_pandas()

    def _sync_pandas(self):
        """Synchronize the internal pandas DataFrame with the polars DataFrame."""
        self.pandas_df = self.polars_df.to_pandas()

    def unnest(self, col):
        """Unnest a column in the polars DataFrame and sync with pandas.

        Args:
            col (str): Column name to unnest.

        Returns:
            LegacyObsmDataframe: Self for method chaining.
        """
        self.polars_df = self.polars_df.unnest(col)
        self._sync_pandas()
        return self  # For method chaining

    # Expose pandas DataFrame methods
    def __getattr__(self, name):
        # Delegate attribute access to pandas_df
        return getattr(self.pandas_df, name)

    def __getitem__(self, key):
        return self.pandas_df[key]

    def __setitem__(self, key, value):
        raise NotImplementedError(
            "Direct assignment to columns is not allowed. Please use .copy() to create an editable copy of the DataFrame."
        )


class DataFrame(pd.DataFrame):
    """Extended pandas DataFrame with additional functionality.

    Args:
        *args: Positional arguments passed to pd.DataFrame.
        **kwargs: Keyword arguments passed to pd.DataFrame.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def from_pandas(cls, df: pd.DataFrame):
        """Create a DataFrame from an existing pandas DataFrame.

        Args:
            df (pd.DataFrame): The pandas DataFrame to wrap.

        Returns:
            DataFrame: New instance of this class.
        """
        return cls(df)

    def unnest(self, col: str):
        """Unnest a column containing nested data.

        Args:
            col (str): Column name to unnest.

        Returns:
            DataFrame: DataFrame with unnested column.
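
        Example:
            A minimal sketch; the dict-valued column ``labs`` is hypothetical:

            >>> df = DataFrame({"labs": [{"na": 140}, {"na": 138}]})
            >>> df.unnest("labs")  # expands via pd.json_normalize and drops "labs"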
        """
        return self.join(pd.json_normalize(self[col])).drop(columns=[col])


class Cohort(NewCohort):
    """Legacy class to build a cohort in the CORR database.

    This version uses pandas to access the .obs and .obsm attributes. Please
    migrate to the new version of the cohort class at your earliest convenience.

    Args:
        conn_args (dict): Dictionary of database credentials
            [remote_hostname (str), username (str)] (default: {}).
        logger_args (dict): Dictionary of logging configurations
            [level (int), file_path (str), file_mode (str), verbose_fmt (bool),
            colored_output (bool), formatted_numbers (bool)] (default: {}).
        password_file (str | bool): Path to the password file, or True if your
            file is at ~/password.txt (default: None).
        database (Literal["db_hypercapnia_prepared", "db_corror_prepared"]):
            Database to use (default: "db_hypercapnia_prepared").
        extraction_end_date (str): Deprecated as of Feb 6, 2025. This was used
            to set the end date for the extraction. Use filters instead
            (default: None).
        obs_level (Literal["icu_stay", "hospital_stay", "procedure"]):
            Observation level (default: "icu_stay").
        project_vars (dict): Dictionary with local variable definitions
            (default: {}).
        merge_consecutive (bool): Whether to merge consecutive ICU stays
            (default: True). Does not apply to any other obs_level.
        load_default_vars (bool): Whether to load the default variables
            (default: True).
        filters (str): Initial filters (must be a valid SQL WHERE clause for
            the it_ishmed_fall table) (default: "").

    Attributes:
        obs (pd.DataFrame): Static data for each observation. Contains one row
            per observation (e.g., ICU stay) with columns for static variables
            like demographics and outcomes.

            Example:
                >>> cohort.obs
                  patient_id case_id icu_stay_id       icu_admission       icu_discharge sex ... inhospital_death
                0       P001    C001      C001_1 2023-01-01 08:30:00 2023-01-03 12:00:00   M ...            False
                1       P001    C001      C001_2 2023-01-03 14:20:00 2023-01-05 16:30:00   M ...            False
                2       P002    C002      C002_1 2023-01-02 09:15:00 2023-01-04 10:30:00   F ...            False
                3       P003    C003      C003_1 2023-01-04 11:45:00 2023-01-07 13:20:00   F ...             True
                ...

        obsm (dict of pd.DataFrame): Dynamic data stored as a dictionary of
            DataFrames. Each DataFrame contains time-series data for a variable
            with columns:

            - recordtime: Timestamp of the measurement
            - value: Value of the measurement
            - recordtime_end: End time (only for duration-based variables like therapies)
            - description: Additional information (e.g., medication names)

            Example:
                >>> cohort.obsm["blood_sodium"]
                  icu_stay_id          recordtime  value
                0      C001_1 2023-01-01 09:30:00    138
                1      C001_1 2023-01-02 10:15:00    141
                2      C001_2 2023-01-03 15:00:00    137
                3      C002_1 2023-01-02 10:00:00    142
                4      C003_1 2023-01-04 12:30:00    139
                ...

        variables (dict of Variable): Dictionary of all variable objects in the
            cohort. Used to keep track of variable metadata.

    Notes:
        - For large cohorts, set ``load_default_vars=False`` to speed up the
          extraction. You can use pre-extracted cohorts as starting points and
          load them using ``Cohort.load()``.
        - Variables can be added using ``cohort.add_variable()``. Static
          variables will be added to ``obs``, dynamic variables to ``obsm``.
        - ``filters`` also allows a special shorthand "_dx" to extract the
          hospital admissions of the last x months, useful for
          debugging/prototyping. For example, use "_d2" to extract every
          admission of the last 2 months.

    Examples:
        Create a new cohort:

        >>> cohort = Cohort(obs_level="icu_stay",
        ...                 database="db_hypercapnia_prepared",
        ...                 load_default_vars=False,
        ...                 password_file=True)

        Access static data:

        >>> cohort.obs["age_on_admission"]  # Get age for all patients
        >>> cohort.obs.loc[cohort.obs["sex"] == "M"]  # Filter for male patients

        Access time-series data:

        >>> cohort.obsm["blood_sodium"]  # Get all blood sodium measurements
        >>> # Get blood sodium measurements for a specific observation
        >>> cohort.obsm["blood_sodium"].loc[
        ...     cohort.obsm["blood_sodium"][cohort.primary_key] == "12345"
        ... ]
    """
    def __init__(
        self,
        conn_args: dict = {},
        logger_args: dict = {},
        password_file: str | bool = None,
        database: Literal[
            "db_hypercapnia_prepared", "db_corror_prepared"
        ] = "db_hypercapnia_prepared",
        extraction_end_date: str = None,
        obs_level: Literal["icu_stay", "hospital_stay", "procedure"] = "icu_stay",
        project_vars: dict = {},
        merge_consecutive: bool = True,
        load_default_vars: bool = True,
        filters: str = "",
    ):
        # Add flag to indicate this is a newly created cohort
        self.from_file = False
        self.conn_args = conn_args
        self.logger_args = logger_args
        self.database = database
        self.obs_level = obs_level
        self.merge_consecutive = merge_consecutive
        self.project_vars = project_vars
        super().__init__(
            logger_args=logger_args,
            extraction_end_date=extraction_end_date,
            project_vars=project_vars,
            load_default_vars=load_default_vars,
            obs_level=obs_level,
            sources={
                "cub_hdp": {
                    "filters": filters,
                    "database": database,
                    "password_file": password_file,
                    "merge_consecutive": merge_consecutive,
                    "conn_args": conn_args,
                }
            },
        )
    @property
    def obs(self):
        return LegacyObsDataFrame(self._obs.to_pandas(), cohort=self)

    @property
    def obsm(self):
        return {k: LegacyObsmDataframe(v) for k, v in self._obsm.items()}

    @obsm.setter
    def obsm(self, value):
        raise ValueError(
            "Cannot modify obsm in the legacy version. Please migrate to the polars-native version instead."
        )

    # Backwards compatibility
    add_inclusion = NewCohort.include_list
    add_exclusion = NewCohort.exclude_list

    # Save and load methods
    def save(self, filename: str, legacy: bool = False):
        """Save the cohort to a .corr2 archive.

        Args:
            filename: Path to the .corr2 archive.
            legacy: Whether to save as a legacy .corr pickle instead
                (strongly discouraged) (default: False).
        """
        if legacy:
            logger.warning(
                "Saving as .corr is strongly discouraged. Please use .corr2 instead."
            )
            conn = self.conn
            self.conn = None
            # Add .corr suffix
            if not filename.endswith(".corr"):
                filename = filename + ".corr"
            with open(filename, "wb") as f:
                pickle.dump(self, f)
            self.conn = conn
            logger.info(f"SUCCESS: Saved cohort to {filename}")
            return
        if filename.endswith(".corr"):
            logger.warning(
                "Saving to .corr is deprecated. Please use .corr2 instead. This file will be saved as .corr2."
            )
            filename = filename[:-5] + ".corr2"
        self.save_corr2(filename)
    @classmethod
    def load(cls, filename: str, password_file: str = None):
        """Load a cohort from a pickle file.

        If this file was saved by a different user, you need to pass your
        database credentials to the function.

        Args:
            filename: Path to the pickle file.
            password_file: Path to the password file.

        Returns:
            Cohort: A new Cohort object.
        """
        if filename.endswith(".corr2"):
            return cls.load_corr2(filename, password_file)
        if filename.endswith(".gzip"):
            logger.warning(
                "Loading from .corr.gzip is deprecated. Please save your file as .corr2 in the future."
            )
            with gzip.open(filename, "rb") as f:
                cohort = pickle.load(f)
        elif filename.endswith(".corr"):
            logger.warning(
                "Loading from .corr is deprecated. Please save your file as .corr2 in the future."
            )
            with open(filename, "rb") as f:
                cohort = pickle.load(f)
        else:
            raise ValueError(
                f"Supported file extensions are .corr2, .corr.gzip, and .corr, you provided {filename}"
            )
        password_file = password_file or getattr(cohort, "password_file", None)
        cohort.create_tmpdir(check_existing=True)
        cohort.from_file = True
        logger.info(f"SUCCESS: Loaded cohort from {filename}")
        return cohort
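
    # A minimal save/load round trip as a sketch (the file name
    # "my_cohort.corr2" is hypothetical):
    #
    #   cohort.save("my_cohort.corr2")           # delegates to save_corr2()
    #   cohort = Cohort.load("my_cohort.corr2")  # delegates to load_corr2()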
    def __getstate__(self):
        """Called when pickling the object. Removes the unpickleable connection."""
        state = self.__dict__.copy()
        if "conn" in state:
            del state["conn"]  # Avoid pickling the connection
        return state

    def __setstate__(self, state):
        """Called when unpickling the object. Re-initializes the connection."""
        self.__dict__.update(state)
        self.conn = None  # Will be re-established when load() is called

    def _extr_end_date(self, _):
        raise NotImplementedError("This function has been removed as of Feb 6, 2025")

    def _filter_by_tmin_tmax(self, df_var, tmin, tmax):
        raise NotImplementedError(
            "This function has been removed as of May 19, 2025. It was never in use anyway. If you need it, please create an issue on GitHub."
        )