# Source code for corr_vars.sources.reprodicu.cohort
import polars as pl
from corr_vars import logger
from corr_vars.sources.reprodicu.config import SourceDict
from corr_vars.sources.reprodicu import helpers
from corr_vars.definitions import ObsLevel
# [docs]  (Sphinx viewcode anchor artifact)
def load_data(obs_level, source_kwargs: SourceDict):
"""
Load data from reprodicu parquet files.
Args:
obs_level (ObsLevel): Level of observation ("icu_stay", "hospital_stay", "procedure")
source_kwargs: Dictionary containing source configuration
- path: Path to reprodicu parquet files
- exclude_datasets: List of datasets to exclude (optional)
- include_datasets: List of datasets to include (optional)
Returns:
pl.DataFrame: Loaded observation data
"""
path = source_kwargs["path"]
exclude_datasets = source_kwargs["exclude_datasets"]
include_datasets = source_kwargs["include_datasets"]
if not path:
raise ValueError(
"Path to reprodicu data files must be specified in source_kwargs['path']"
)
# Load patient information
logger.info("Loading reprodicu patient information")
ov = helpers.read_parquet(path, "patient_information")
# Handle different observation levels
match obs_level:
case ObsLevel.HOSPITAL_STAY:
raise NotImplementedError(
"Hospital stay level not implemented for reprodicu"
)
case ObsLevel.PROCEDURE:
raise NotImplementedError("Procedure level not implemented for reprodicu")
case ObsLevel.ICU_STAY:
# Default case for ICU stay level
pass
case _:
raise ValueError(f"Unsupported observation level: {obs_level}")
# Filter datasets if specified
if include_datasets:
logger.info(f"Including only datasets: {include_datasets}")
ov = ov.filter(pl.col("Source Dataset").is_in(include_datasets))
if exclude_datasets:
logger.info(f"Excluding datasets: {exclude_datasets}")
ov = ov.filter(pl.col("Source Dataset").is_in(exclude_datasets).not_())
# Column mapping from reprodicu to corr-vars standard names
column_mapping = {
"Global Person ID": "patient_id",
"Global Hospital Stay ID": "case_id",
"Global ICU Stay ID": "icu_stay_id",
"Admission Age (years)": "age_on_admission",
"Admission Height (cm)": "height_on_admission",
"Admission Weight (kg)": "weight_on_admission",
"Gender": "sex",
"Ethnicity": "ethnicity",
"Primary (Admission) Diagnosis (ICD)": "primary_diagnosis",
"Care Site": "icu_id",
"Mortality in Hospital": "inhospital_death",
"Source Dataset": "reprodicu_dataset",
"ICU Length of Stay (days)": "icu_length_of_stay",
"Hospital Length of Stay (days)": "hospital_length_of_stay",
# "icu_admission" is added later
# "icu_discharge" is added later
}
# Select and rename columns that exist in the dataset
available_columns = ov.columns
filtered_column_mapping = {
k: v for k, v in column_mapping.items() if k in available_columns
}
missing_columns = set(column_mapping.keys()) - set(ov.columns)
unused_columns = set(ov.columns) - set(column_mapping.keys())
logger.info(f"Available columns:\n{helpers.pretty_join(available_columns)}")
if not filtered_column_mapping:
raise ValueError("No expected columns found in patient_information.")
elif missing_columns:
raise ValueError(
f"Some columns were not not found in patient_information.\nUnused:\n{helpers.pretty_join(unused_columns)}\nMissing:\n{helpers.pretty_join(missing_columns)}"
)
ov = ov.select(filtered_column_mapping.keys()).rename(filtered_column_mapping)
# Apply standard transformations
first_transformations = []
last_transformations = []
# Sex transformation
if "sex" in ov.columns:
first_transformations.append(
pl.col("sex")
.cast(pl.Utf8)
.replace({"Male": "M", "Female": "F"})
.alias("sex")
)
# Boolean transformations
if "inhospital_death" in ov.columns:
first_transformations.append(
pl.col("inhospital_death").cast(pl.Boolean).alias("inhospital_death")
)
# Date/time transformations (if columns exist)
# datetime_columns = [
# "icu_admission",
# "icu_discharge",
# "hospital_admission",
# "hospital_discharge",
# ]
# for col in datetime_columns:
# if col in ov.columns:
# first_transformations.append(
# pl.col(col)
# .str.strptime(pl.Datetime, format="%Y-%m-%d %H:%M:%S", strict=False)
# .alias(col)
# )
# For reprodicu, if we don't have actual admission/discharge times,
# we need to create them based on the relative time structure
if "icu_admission" not in ov.columns:
# Create a dummy ICU admission time - this will be the reference point (time 0)
# for the relative times in the timeseries data
first_transformations.append(
pl.datetime(1900, 1, 1, 0, 0, 0).alias("icu_admission")
)
if "hospital_admission" not in ov.columns:
# Set hospital admission to be the same as ICU admission for simplicity
if "icu_admission" not in ov.columns:
# Both are missing, create dummy times
first_transformations.append(
pl.datetime(1900, 1, 1, 0, 0, 0).alias("hospital_admission")
)
else:
# ICU admission exists (either original or will be created), reference it
first_transformations.append(
pl.col("icu_admission").alias("hospital_admission")
)
if "icu_discharge" not in ov.columns:
last_transformations.append(
(
pl.col("icu_admission")
+ pl.duration(seconds=pl.col("icu_length_of_stay") * 86400)
).alias("icu_discharge")
)
if "hospital_discharge" not in ov.columns:
last_transformations.append(
(
pl.col("hospital_admission")
+ pl.duration(seconds=pl.col("hospital_length_of_stay") * 86400)
).alias("hospital_discharge")
)
if first_transformations:
ov = ov.with_columns(first_transformations)
if last_transformations:
ov = ov.with_columns(last_transformations)
logger.info(f"Loaded {len(ov)} {obs_level} observations from reprodicu")
logger.info(
f"Datasets included: {', '.join(ov.select('data_source').unique().to_series().to_list()) if 'data_source' in ov.columns else 'Unknown'}"
)
return ov