# Standard library imports
import gzip
import pickle
from typing import Literal

# Third-party imports
import ehrapy as ep
import pandas as pd
import polars as pl

# Local imports
from corr_vars import logger
from corr_vars.core.cohort import Cohort as NewCohort
# Put the old cohort tracker back in for compatibility with older .corr files
class CohortTracker(ep.tl.CohortTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
    def plot_cohort_barplot(self, *args, **kwargs):
        """Disabled: column tracking is not implemented due to unhashable column types."""
        raise NotImplementedError(
            "plot_cohort_barplot is disabled in the legacy CohortTracker."
        )
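
# Note: old .corr pickles that reference a CohortTracker resolve to the shim above,
# so unpickling succeeds; calling plot_cohort_barplot() on the restored object
# raises NotImplementedError by design.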
class LegacyObsDataFrame(pd.DataFrame):
"""Provides functionality to access the obs dataframe as a pandas dataframe.
Also includes a method for direct column assignment.
Args:
*args: Positional arguments passed to pd.DataFrame.
cohort (Cohort, optional): Reference to the parent Cohort instance (default: None).
**kwargs: Keyword arguments passed to pd.DataFrame.
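
    Example:
        A write-through sketch (assumes a connected ``cohort``; the column name
        ``flag`` is illustrative):
        >>> obs = cohort.obs
        >>> obs["flag"] = True  # also mirrored into cohort._obs as a polars column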
"""
def __init__(self, *args, cohort=None, **kwargs):
super().__init__(*args, **kwargs)
self._cohort = cohort
def __setitem__(self, key, value):
if self._cohort is None:
raise ValueError("DataFrame is not connected to a Cohort instance")
if hasattr(value, "__len__") and not isinstance(value, str):
if len(value) != self.shape[0]:
raise ValueError("Length of value does not match number of rows.")
self._cohort._obs = self._cohort._obs.with_columns(pl.Series(key, value))
else:
self._cohort._obs = self._cohort._obs.with_columns(pl.lit(value).alias(key))
super().__setitem__(key, value)
class LegacyObsmDataframe:
"""Provides functionality to access the obsm dataframe as a pandas dataframe.
Using polars-native unnest method to increase performance over pd.json_normalize.
Args:
polars_df (pl.DataFrame): The polars DataFrame to wrap.
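
    Example:
        A minimal sketch of the wrap-and-unnest flow (``struct_col`` is a
        hypothetical struct column):
        >>> df = LegacyObsmDataframe(pl.DataFrame({"struct_col": [{"a": 1}]}))
        >>> df.unnest("struct_col")["a"]  # reads are delegated to pandas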
"""
def __init__(self, polars_df: pl.DataFrame):
self.polars_df = polars_df
self._sync_pandas()
def _sync_pandas(self):
"""Synchronize the internal pandas DataFrame with the polars DataFrame."""
self.pandas_df = self.polars_df.to_pandas()
def unnest(self, col):
"""Unnest a column in the polars DataFrame and sync with pandas.
Args:
col (str): Column name to unnest.
Returns:
LegacyObsmDataframe: Self for method chaining.
"""
self.polars_df = self.polars_df.unnest(col)
self._sync_pandas()
return self # For method chaining
# Expose pandas DataFrame methods
def __getattr__(self, name):
# Delegate attribute access to pandas_df
return getattr(self.pandas_df, name)
def __getitem__(self, key):
return self.pandas_df[key]
def __setitem__(self, key, value):
        raise NotImplementedError(
            "Direct assignment to columns is not allowed. Use .copy() to create an "
            "editable copy of the DataFrame."
        )
class DataFrame(pd.DataFrame):
"""Extended pandas DataFrame with additional functionality.
Args:
*args: Positional arguments passed to pd.DataFrame.
**kwargs: Keyword arguments passed to pd.DataFrame.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def from_pandas(cls, df: pd.DataFrame):
"""Create a DataFrame from an existing pandas DataFrame.
Args:
df (pd.DataFrame): The pandas DataFrame to wrap.
Returns:
DataFrame: New instance of this class.
"""
return cls(df)
def unnest(self, col: str):
"""Unnest a column containing nested data.
Args:
col (str): Column name to unnest.
Returns:
DataFrame: DataFrame with unnested column.
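
        Example:
            A minimal sketch with a dict-valued column (names are illustrative):
            >>> df = DataFrame({"id": [1], "meta": [{"a": 1, "b": 2}]})
            >>> sorted(df.unnest("meta").columns)
            ['a', 'b', 'id']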
"""
        # json_normalize returns a fresh RangeIndex; realign it to self.index so the
        # join cannot mis-align rows when this DataFrame has a non-default index.
        return self.join(pd.json_normalize(self[col]).set_index(self.index)).drop(
            columns=[col]
        )
class Cohort(NewCohort):
"""Legacy class to build a cohort in the CORR database.
    This version exposes the .obs and .obsm attributes as pandas-backed objects.
Please migrate to the new version of the cohort class at your earliest convenience.
Args:
conn_args (dict): Dictionary of Database credentials [remote_hostname (str), username (str)] (default: {}).
logger_args (dict): Dictionary of Logging configurations [level (int), file_path (str), file_mode (str), verbose_fmt (bool), colored_output (bool), formatted_numbers (bool)] (default: {}).
password_file (str | bool): Path to the password file or True if your file is in ~/password.txt (default: None).
database (Literal["db_hypercapnia_prepared", "db_corror_prepared"]): Database to use (default: "db_hypercapnia_prepared").
extraction_end_date (str): Deprecated as of Feb 6, 2025. This was used to set the end date for the extraction. Use filters instead (default: None).
obs_level (Literal["icu_stay", "hospital_stay", "procedure"]): Observation level (default: "icu_stay").
project_vars (dict): Dictionary with local variable definitions (default: {}).
merge_consecutive (bool): Whether to merge consecutive ICU stays (default: True). Does not apply to any other obs_level.
load_default_vars (bool): Whether to load the default variables (default: True).
filters (str): Initial filters (must be a valid SQL WHERE clause for the it_ishmed_fall table) (default: "").
Attributes:
obs (pd.DataFrame):
Static data for each observation. Contains one row per observation (e.g., ICU stay)
with columns for static variables like demographics and outcomes.
Example:
>>> cohort.obs
patient_id case_id icu_stay_id icu_admission icu_discharge sex ... inhospital_death
0 P001 C001 C001_1 2023-01-01 08:30:00 2023-01-03 12:00:00 M ... False
1 P001 C001 C001_2 2023-01-03 14:20:00 2023-01-05 16:30:00 M ... False
2 P002 C002 C002_1 2023-01-02 09:15:00 2023-01-04 10:30:00 F ... False
3 P003 C003 C003_1 2023-01-04 11:45:00 2023-01-07 13:20:00 F ... True
...
obsm (dict of pd.DataFrame):
Dynamic data stored as dictionary of DataFrames. Each DataFrame contains time-series
data for a variable with columns:
- recordtime: Timestamp of the measurement
- value: Value of the measurement
- recordtime_end: End time (only for duration-based variables like therapies)
- description: Additional information (e.g., medication names)
Example:
>>> cohort.obsm["blood_sodium"]
icu_stay_id recordtime value
0 C001_1 2023-01-01 09:30:00 138
1 C001_1 2023-01-02 10:15:00 141
2 C001_2 2023-01-03 15:00:00 137
3 C002_1 2023-01-02 10:00:00 142
4 C003_1 2023-01-04 12:30:00 139
...
variables (dict of Variable):
Dictionary of all variable objects in the cohort. This is used to keep track of variable metadata.
Notes:
- For large cohorts, set ``load_default_vars=False`` to speed up the extraction. You can use
pre-extracted cohorts as starting points and load them using ``Cohort.load()``.
- Variables can be added using ``cohort.add_variable()``. Static variables will be added to ``obs``,
dynamic variables to ``obsm``.
    - ``filters`` also accepts a special shorthand ``"_dx"`` that restricts extraction to the
      hospital admissions of the last x months, which is useful for debugging/prototyping. For
      example, ``"_d2"`` extracts every admission from the last 2 months.
Examples:
Create a new cohort:
>>> cohort = Cohort(obs_level="icu_stay",
... database="db_hypercapnia_prepared",
... load_default_vars=False,
... password_file=True)
Access static data:
>>> cohort.obs["age_on_admission"] # Get age for all patients
>>> cohort.obs.loc[cohort.obs["sex"] == "M"] # Filter for male patients
Access time-series data:
>>> cohort.obsm["blood_sodium"] # Get all blood sodium measurements
>>> # Get blood sodium measurements for a specific observation
>>> cohort.obsm["blood_sodium"].loc[
... cohort.obsm["blood_sodium"][cohort.primary_key] == "12345"
... ]
"""
    def __init__(
        self,
        conn_args: dict | None = None,
        logger_args: dict | None = None,
        password_file: str | bool | None = None,
        database: Literal[
            "db_hypercapnia_prepared", "db_corror_prepared"
        ] = "db_hypercapnia_prepared",
        extraction_end_date: str | None = None,
        obs_level: Literal["icu_stay", "hospital_stay", "procedure"] = "icu_stay",
        project_vars: dict | None = None,
        merge_consecutive: bool = True,
        load_default_vars: bool = True,
        filters: str = "",
    ):
        # Avoid the shared mutable-default pitfall: normalize None to fresh dicts
        conn_args = conn_args if conn_args is not None else {}
        logger_args = logger_args if logger_args is not None else {}
        project_vars = project_vars if project_vars is not None else {}
# Add flag to indicate this is a newly created cohort
self.from_file = False
self.conn_args = conn_args
self.logger_args = logger_args
self.database = database
self.obs_level = obs_level
self.merge_consecutive = merge_consecutive
self.project_vars = project_vars
super().__init__(
logger_args=logger_args,
extraction_end_date=extraction_end_date,
project_vars=project_vars,
load_default_vars=load_default_vars,
obs_level=obs_level,
sources={
"cub_hdp": {
"filters": filters,
"database": database,
"password_file": password_file,
"merge_consecutive": merge_consecutive,
"conn_args": conn_args,
}
},
)
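
    # NOTE: `obs` and `obsm` are materialized from the internal polars frames on
    # every property access, so repeated access in tight loops re-runs the
    # polars-to-pandas conversion; cache the result locally if that matters.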
@property
def obs(self):
return LegacyObsDataFrame(self._obs.to_pandas(), cohort=self)
@property
def obsm(self):
return {k: LegacyObsmDataframe(v) for k, v in self._obsm.items()}
@obsm.setter
def obsm(self, value):
        raise ValueError(
            "Cannot modify obsm in the legacy version. Please migrate to the "
            "polars-native Cohort class instead."
        )
# Backwards compatibility
add_inclusion = NewCohort.include_list
add_exclusion = NewCohort.exclude_list
# Save and load methods
def save(self, filename: str, legacy: bool = False):
"""
        Save the cohort to a .corr2 archive.

        Args:
            filename: Path to the archive. A ``.corr`` suffix is rewritten to ``.corr2``
                unless ``legacy=True``.
            legacy: If True, write a legacy pickle-based ``.corr`` file instead
                (strongly discouraged).
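
        Example:
            A minimal sketch (the path is illustrative):
            >>> cohort.save("my_cohort.corr2")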
"""
if legacy:
logger.warning(
"Saving as .corr is strongly discouraged. Please use .corr2 instead."
)
conn = self.conn
self.conn = None
# Add .corr suffix
if not filename.endswith(".corr"):
filename = filename + ".corr"
with open(filename, "wb") as f:
pickle.dump(self, f)
self.conn = conn
logger.info(f"SUCCESS: Saved cohort to {filename}")
return
if filename.endswith(".corr"):
logger.warning(
"Saving to .corr is deprecated. Please use .corr2 instead. This file will be saved as .corr2."
)
filename = filename[:-5] + ".corr2"
self.save_corr2(filename)
@classmethod
    def load(cls, filename: str, password_file: str | None = None):
"""
        Load a cohort from a .corr2 archive or a legacy pickle file (.corr / .corr.gzip).
        If the file was saved by a different user, pass your own database credentials.

        Args:
            filename: Path to the saved cohort file.
            password_file: Path to the password file.
Returns:
Cohort: A new Cohort object.
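
        Example:
            A minimal sketch (paths are illustrative):
            >>> cohort = Cohort.load("my_cohort.corr2", password_file="~/password.txt")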
"""
if filename.endswith(".corr2"):
return cls.load_corr2(filename, password_file)
if filename.endswith(".gzip"):
logger.warning(
"Loading from .corr.gzip is deprecated. Please save your file as .corr2 in the future."
)
with gzip.open(filename, "rb") as f:
cohort = pickle.load(f)
elif filename.endswith(".corr"):
logger.warning(
"Loading from .corr is deprecated. Please save your file as .corr2 in the future."
)
with open(filename, "rb") as f:
cohort = pickle.load(f)
else:
raise ValueError(
f"Supported file extensions are .corr2, .corr.gzip, and .corr, you provided {filename}"
)
        # Re-attach credentials: prefer the caller's password_file, falling back to
        # whatever was stored with the pickled cohort.
        cohort.password_file = password_file or getattr(cohort, "password_file", None)
cohort.create_tmpdir(check_existing=True)
cohort.from_file = True
logger.info(f"SUCCESS: Loaded cohort from {filename}")
return cohort
def __getstate__(self):
"""
Called when pickling the object. Removes the unpickleable connection.
"""
state = self.__dict__.copy()
if "conn" in state:
del state["conn"] # Avoid pickling the connection
return state
def __setstate__(self, state):
"""
Called when unpickling the object. Re-initializes the connection.
"""
self.__dict__.update(state)
self.conn = None # Will be re-established when load() is called
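
    # Round-trip sketch: `pickle.loads(pickle.dumps(cohort))` yields a cohort whose
    # `conn` is None until credentials are re-applied, e.g. via `load()`.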
def _extr_end_date(self, _):
raise NotImplementedError("This function has been removed as of Feb 6, 2025")
def _filter_by_tmin_tmax(self, df_var, tmin, tmax):
        raise NotImplementedError(
            "This function was removed on May 19, 2025; it was never in use. "
            "If you need it, please open an issue on GitHub."
        )