from __future__ import annotations
from copy import deepcopy
import inspect
import warnings
import logging
import numpy as np
import polars as pl
import pandas as pd
from pandas.core.groupby import DataFrameGroupBy
from tqdm import tqdm
from corr_vars import logger
import corr_vars.sources as sources
from corr_vars.utils.logging import log_collection
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from corr_vars.definitions import (
TimeBoundColumn,
IntoExprColumn,
PandasDataframe,
PolarsDataframe,
PolarsFrame,
)
from typing import Literal, Any, overload, TYPE_CHECKING, TypeVar, cast
if TYPE_CHECKING:
from corr_vars.core.variable import Variable
# Type variable for mutable, string-keyed mappings. It lets helpers such as
# `deep_merge` declare that they return the same mapping type they were given.
T = TypeVar("T", bound=MutableMapping[str, Any])
# Unused.
[docs]
def filter_by_condition(
    df: pl.DataFrame,
    expression: (
        IntoExprColumn
        | Iterable[IntoExprColumn]
        | bool
        | list[bool]
        | np.ndarray[Any, Any]
    ),
    description: str = "filter_by_condition call",
    verbose: bool = True,
    mode: Literal["drop", "keep"] = "drop",
) -> pl.DataFrame:
    """
    Drop or keep rows of a polars DataFrame based on a boolean condition.

    Args:
        df (pl.DataFrame): The input DataFrame.
        expression: Expression(s) that evaluate to boolean column(s).
        description (str): Description of the condition, used in the log message.
        verbose (bool): Whether to log how many rows were dropped / kept.
        mode (str): Whether to drop or keep the matching rows. Can be "drop" or "keep".

    Returns:
        pl.DataFrame: The filtered DataFrame.

    Raises:
        ValueError: If `mode` is invalid or `expression` does not evaluate to
            boolean column(s).
    """
    if mode not in ("drop", "keep"):
        raise ValueError(f"Mode must be 'drop' or 'keep'. Got {mode}")

    mask = df.select(expression)
    if any(dtype != pl.Boolean for dtype in mask.schema.values()):
        raise ValueError(
            "Conditional expression must return a boolean Series/Dataframe."
        )
    if mask.null_count().select(pl.any_horizontal(pl.all().gt(0))).item():
        warnings.warn(
            "Condition expression should return a boolean Series/Dataframe without NA values. Consider the use of `fill_null`.",
            UserWarning,
        )

    filtered = df.remove(expression) if mode == "drop" else df.filter(expression)

    if verbose:
        total_rows = len(df)
        # Count the rows that were actually affected. Summing the mask would
        # overcount when several predicates are passed (one column each).
        affected = total_rows - len(filtered) if mode == "drop" else len(filtered)
        # Guard against empty input, which would otherwise divide by zero.
        share = affected / total_rows if total_rows else 0.0
        logger.info(
            f"{mode.upper()}: {affected} rows ({share:.2%}) due to {description}"
        )
    return filtered
# Unused. Written for pandas.
[docs]
def df_find_closest(
    group: pd.DataFrame,
    to_col: str,
    tdelta: str = "0",
    pm_before: str = "52w",
    pm_after: str = "52w",
) -> pd.Series[Any] | pd.DataFrame:
    """Find the closest record to the target time.

    The target time is `group[to_col] + tdelta`. Only records whose
    "recordtime" lies within `[target - pm_before, target + pm_after]` are
    considered. The input frame is not modified.

    Args:
        group (pd.DataFrame): The input dataframe. Must contain a "recordtime" column.
        to_col (str): The column containing the target time.
        tdelta (str): The time delta added to the target time. Defaults to "0".
        pm_before (str): The time range before the target time. Defaults to "52w".
        pm_after (str): The time range after the target time. Defaults to "52w".

    Returns:
        pd.Series[Any] | pd.DataFrame: The closest record, indexed by the columns
        of `group`. An all-NaN Series is returned if no record is in range. A
        DataFrame is returned if the index label of the closest record is not unique.
    """
    offset = pd.to_timedelta(tdelta)
    window_before = pd.to_timedelta(pm_before)
    window_after = pd.to_timedelta(pm_after)

    target_time = group[to_col] + offset
    record_time = group["recordtime"]

    in_window = (record_time >= (target_time - window_before)) & (
        record_time <= (target_time + window_after)
    )
    candidates = group[in_window]
    if candidates.empty:
        return pd.Series(np.nan, index=group.columns)

    # Keep the distance in a separate Series instead of a helper column, so the
    # caller's frame is left untouched and nothing has to be dropped afterwards.
    timediff = (record_time - target_time).abs()
    return candidates.loc[timediff[in_window].idxmin()]
# Unused. Written for pandas.
[docs]
def aggregate_column(
group: DataFrameGroupBy, agg_func: str, column: str, params: list[str] = []
) -> pd.Series[Any]:
match agg_func:
case "median":
return group[column].median()
case "mean":
return group[column].mean()
case "perc":
return group[column].quantile(float(params[0]) / 100)
case "count":
return group.size()
case _:
raise ValueError(f"Unknown aggregation function: {agg_func}")
# Typing overloads: the returned frame has the same backend as the input frame.
@overload
def get_cb(
    df: PolarsDataframe,
    join_key: str = ...,
    primary_key: str = ...,
    tmin_col: str = ...,
    tmax_col: str = ...,
    ttarget_col: str | None = ...,
) -> PolarsDataframe: ...


@overload
def get_cb(
    df: PandasDataframe,
    join_key: str = ...,
    primary_key: str = ...,
    tmin_col: str = ...,
    tmax_col: str = ...,
    ttarget_col: str | None = ...,
) -> PandasDataframe: ...
[docs]
def get_cb(
    df: PandasDataframe | PolarsDataframe,
    join_key: str = "case_id",
    primary_key: str = "case_id",
    tmin_col: str = "tmin",
    tmax_col: str = "tmax",
    ttarget_col: str | None = None,
) -> PandasDataframe | PolarsDataframe:
    """Build the case bounds Dataframe used to time-filter a native dynamic variable.

    Rows with a missing `tmin_col` or `tmax_col` are removed, the relevant
    columns are selected, and the bound columns are renamed to "case_tmin" /
    "case_tmax".

    Args:
        df (pd.DataFrame | pl.DataFrame): The cohort dataframe.
        join_key (str): The join key column. Defaults to "case_id".
        primary_key (str): The primary key column. Only added to the output when
            it differs from `join_key`. Defaults to "case_id".
        tmin_col (str): The column containing the tmin. Defaults to "tmin".
        tmax_col (str): The column containing the tmax. Defaults to "tmax".
        ttarget_col (str | None): The column containing the target time. Defaults to None. (For !closest)

    Returns:
        result: The case bounds Dataframe with the same type as `df`.
    """
    # The selected columns and the renaming are the same for both backends.
    wanted = [join_key, tmin_col, tmax_col]
    if ttarget_col:
        wanted.append(ttarget_col)
    if primary_key != join_key:
        wanted.append(primary_key)
    bound_names = {tmin_col: "case_tmin", tmax_col: "case_tmax"}

    if not isinstance(df, pl.DataFrame):
        # pandas path: drop rows with NaN tmin or tmax, then select and rename.
        complete_rows = df.dropna(subset=[tmin_col, tmax_col])
        return complete_rows[wanted].rename(columns=bound_names)

    # polars path
    has_both_bounds = pl.col(tmin_col).is_not_null() & pl.col(tmax_col).is_not_null()
    return (
        df.filter(has_both_bounds)
        # Remove pre-existing case_tmin / case_tmax columns if present.
        .drop(["case_tmin", "case_tmax"], strict=False)
        .select(wanted)
        .rename(bound_names)
    )
def _get_time_series(obs: pd.DataFrame, col_name: str, tdelta: str = "0") -> pd.Series:
    """Return the values of a time column shifted by a time delta.

    Args:
        obs (pd.DataFrame): The observation dataframe.
        col_name (str): Name of the column to shift.
        tdelta (str): Offset, parsed with `pd.to_timedelta`. Defaults to "0".

    Returns:
        pd.Series: `obs[col_name]` with the parsed offset added.
    """
    offset = pd.to_timedelta(tdelta)
    base_times = obs[col_name]
    return base_times + offset
# Typing overloads: the returned frame has the same type as `obs`.
@overload
def pl_parse_time_args(
    obs: PolarsFrame,
    tmin: TimeBoundColumn | None = ...,
    tmax: TimeBoundColumn | None = ...,
) -> PolarsFrame: ...


@overload
def pl_parse_time_args(
    obs: PandasDataframe,
    tmin: TimeBoundColumn | None = ...,
    tmax: TimeBoundColumn | None = ...,
) -> PandasDataframe: ...
[docs]
def pl_parse_time_args(
    obs: PolarsFrame | PandasDataframe,
    tmin: TimeBoundColumn | None = None,
    tmax: TimeBoundColumn | None = None,
) -> PolarsFrame | PandasDataframe:
    """Parse the time arguments and add them as "tmin" / "tmax" columns.

    A time bound is either a column name or a `(column name, offset)` tuple,
    where the offset is a string parsed with `pd.to_timedelta`. Rows whose
    referenced bound column lies in the year 9999 are removed. An unspecified
    bound yields a column of missing values. The input frame is never modified
    in place; a new frame is returned for both backends.

    Args:
        obs (pd.DataFrame | pl.DataFrame | pl.LazyFrame): The observation dataframe.
        tmin (TimeBoundColumn | None): The tmin argument. Defaults to None.
        tmax (TimeBoundColumn | None): The tmax argument. Defaults to None.

    Returns:
        result: The observation dataframe with the same type as obs with tmin and tmax columns.
    """
    if isinstance(obs, pl.DataFrame | pl.LazyFrame):
        # Drop rows where tmin or tmax is in year 9999
        for bound in (tmin, tmax):
            if bound:
                obs = obs.filter(
                    pl.col(_time_bound_column(bound)).dt.year() != 9999
                )
        return obs.with_columns(
            [
                _pl_time_bound_expr(tmin).alias("tmin"),
                _pl_time_bound_expr(tmax).alias("tmax"),
            ]
        )

    # pandas path
    # Drop rows where tmin or tmax is in year 9999
    for bound in (tmin, tmax):
        if bound:
            obs = obs[obs[_time_bound_column(bound)].dt.year != 9999]

    # `assign` returns a new frame: the caller's frame is not mutated and no
    # chained assignment on a filtered slice takes place.
    return obs.assign(
        tmin=_pd_time_bound_series(obs, tmin),
        tmax=_pd_time_bound_series(obs, tmax),
    )


def _time_bound_column(bound: TimeBoundColumn) -> str:
    """Return the column name referenced by a time bound (name or `(name, offset)`)."""
    return bound[0] if isinstance(bound, tuple) else bound


def _pl_time_bound_expr(bound: TimeBoundColumn | None) -> pl.Expr:
    """Translate a time bound into a polars expression (null literal if unspecified)."""
    if isinstance(bound, tuple):
        col_name, tdelta = bound
        # Convert the offset string to a duration and add it to the column.
        return pl.col(col_name) + pl.duration(
            seconds=pd.to_timedelta(tdelta).total_seconds()
        )
    if bound:
        return pl.col(bound)
    return pl.lit(None)


def _pd_time_bound_series(
    obs: pd.DataFrame, bound: TimeBoundColumn | None
) -> pd.Series | float:
    """Translate a time bound into a pandas Series (NaN if unspecified)."""
    if isinstance(bound, tuple):
        col_name, tdelta = bound
        return _get_time_series(obs, col_name, tdelta=tdelta)
    if bound:
        return _get_time_series(obs, col_name=bound)
    return np.nan
# Unused. Written for pandas.
[docs]
def merge_consecutive(
    data: pd.DataFrame,
    primary_key: str,  # Deprecated!
    recordtime: str = "recordtime",
    recordtime_end: str = "recordtime_end",
    time_threshold: pd.Timedelta = pd.Timedelta(minutes=30),
) -> pd.DataFrame:
    """
    Combine consecutive sessions of the same case into a single session.

    Within each "case_id", a session is merged with the previous one when it
    starts less than `time_threshold` after the previous session ended. A merged
    session keeps the first "value" and start time and the last end time.

    Args:
        data (pd.DataFrame): The data to merge. Must contain "case_id", "value"
            and the `recordtime` / `recordtime_end` columns.
        primary_key (str): The primary key column. Deprecated! Will not be used, case_id will be used instead.
        recordtime (str): The recordtime column. Defaults to "recordtime".
        recordtime_end (str): The recordtime_end column. Defaults to "recordtime_end".
        time_threshold (pd.Timedelta): The time threshold. Defaults to 30 minutes.

    Returns:
        pd.DataFrame: The merged data with the columns "case_id", "merge_group",
        "value", `recordtime` and `recordtime_end`.
    """
    # Sort by case and start time so "previous session" is well defined.
    ordered = data.sort_values(by=["case_id", recordtime]).reset_index(drop=True)

    # End time of the previous session within the same case (NaT for the first).
    previous_end = ordered.groupby("case_id")[recordtime_end].shift(1)
    gap_since_previous = ordered[recordtime] - previous_end

    # A NaT gap compares False, so the first session of a case always opens a
    # new merge group.
    starts_new_session = ~(gap_since_previous < time_threshold)
    merge_group = (
        starts_new_session.astype("int64").groupby(ordered["case_id"]).cumsum()
    )

    # Vectorised grouping avoids `groupby.apply`, which relied on the grouping
    # column being visible inside the applied function.
    return (
        ordered.assign(merge_group=merge_group)
        .groupby(["case_id", "merge_group"])
        .agg({"value": "first", recordtime: "first", recordtime_end: "last"})
        .reset_index()
    )
[docs]
def interval_bucket_agg(
    data: PolarsFrame,
    primary_key: str = "case_id",
    recordtime: str = "recordtime",
    recordtime_end: str | None = None,
    recordtime_end_gap_treshold: str | None = None,
    recordtime_end_gap_behavior: Literal["impute", "drop"] = "drop",
    duration_empty_impute_duration: str | None = None,
    value: str = "value",
    tmin: str | None = None,
    tmax: str | None = None,
    t0: str | None = None,
    t_missing_behavior: Literal["keep", "drop"] = "drop",
    granularity: str | None = "1h",
    same_bin_aggregation: Callable[[pl.Expr], pl.Expr] | None = None,
) -> PolarsFrame:
    """
    Bins intervals clipped between tmin and tmax into bins aligned to t0 and aggregates on these bins.

    Values are distributed proportionally to the share of an interval's
    duration that falls into the [tmin, tmax] window and into each bin.

    Args:
        data (pl.DataFrame | pl.LazyFrame): The data to transform.
        primary_key (str): The primary key column. Defaults to "case_id".
        recordtime (str): The recordtime column. Defaults to "recordtime".
        recordtime_end (str | None): The recordtime_end column. Defaults to `None`. Will impute recordtime_end from the next recordtime if set to `None`.
        recordtime_end_gap_treshold (str | None): Maximum allowed gap for recordtime_end. Only applies if recordtime_end is set to `None`.
        recordtime_end_gap_behavior (str): Determines whether to impute or drop rows when the `recordtime_end_gap_treshold` is exceeded.
        duration_empty_impute_duration (str | None): The duration to be imputed by offsetting recordtime_end of zero-length intervals. Drops such rows if set to default of `None` or ≤0s.
        value (str): The value column. Defaults to "value".
        tmin (str | None): The tmin column. Defaults to `None`.
        tmax (str | None): The tmax column. Defaults to `None`.
        t0 (str | None): The t0 column. Bins will align to this datetime col if specified. Defaults to `None`.
        t_missing_behavior (str): Determines whether to drop or keep rows when `t0`, `tmin` or `tmax` columns are specified and values are missing in `data`. Defaults to "drop".
        granularity (str | None): The size of the bucket intervals. Defaults to 1 hour.
        same_bin_aggregation: (Callable[[pl.Expr], pl.Expr] | None): The aggregation to perform on `value` for each bin. Sums the values if set to `None`.

    .. Note::
        The `granularity`, `recordtime_end_gap_treshold` and `duration_empty_impute_duration` arguments are created with
        the following string language:

        - 1ns (1 nanosecond) # nanoseconds are not supported by our DataFrames.
        - 1us (1 microsecond)
        - 1ms (1 millisecond)
        - 1s (1 second)
        - 1m (1 minute)
        - 1h (1 hour)
        - 1d (1 calendar day)
        - 1w (1 calendar week)
        - 1mo (1 calendar month)
        - 1q (1 calendar quarter)
        - 1y (1 calendar year)

        Or combine them: "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds

        By "calendar day", we mean the corresponding time on the next day (which may
        not be 24 hours, due to daylight savings). Similarly for "calendar week",
        "calendar month", "calendar quarter", and "calendar year".

        If you do not wish to use "calendar day" or any other "calendar" durations, specify the duration in hours or smaller units instead.

    Returns:
        clip_bin_data (pl.DataFrame | pl.LazyFrame): The clipped and binned data. Same type as `data`.
    """
    lazy_input = data if isinstance(data, pl.LazyFrame) else data.lazy()

    # Internal column names
    START = "start_time"
    END = "end_time"
    DURATION = "duration_time"
    VALUE = "value"
    TMIN = "tmin"
    TMAX = "tmax"
    T0 = "t0"
    START_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{START}"
    END_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{END}"
    DURATION_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{DURATION}"
    VALUE_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{VALUE}"
    START_SPLIT = f"split_{START}"
    END_SPLIT = f"split_{END}"
    START_OVERLAP = f"overlap_{START}"
    END_OVERLAP = f"overlap_{END}"
    DURATION_OVERLAP = f"overlap_{DURATION}"
    VALUE_OVERLAP = f"overlap_{VALUE}"
    BEFORE_OR_AFTER_T0_SIGNED = "t0_side_signed"
    T0_GRANULARITY_DISTANCE = "t0_granularity_distance"
    AGG_START_OVERLAP = f"first_{START_OVERLAP}"
    AGG_END_OVERLAP = f"last_{END_OVERLAP}"
    AGG_DURATION_OVERLAP_S = f"max_{DURATION_OVERLAP}_s"
    AGG_VALUE = f"agg_{VALUE}"
    ATTRIBUTES = "attributes"

    # Select required columns from input and rename them
    # Fill recordtime_end, tmin, tmax and t0 with None if unspecified
    lf = lazy_input.select(
        primary_key,
        pl.col(recordtime).alias(START),
        (
            pl.lit(None).alias(END)
            if recordtime_end is None
            else pl.col(recordtime_end).alias(END)
        ),
        pl.col(value).alias(VALUE),
        pl.lit(None).alias(TMIN) if tmin is None else pl.col(tmin).alias(TMIN),
        pl.lit(None).alias(TMAX) if tmax is None else pl.col(tmax).alias(TMAX),
        pl.lit(None).alias(T0) if t0 is None else pl.col(t0).alias(T0),
    )

    # impute recordtime_end if necessary
    if recordtime_end is None:
        # Sort and calculate end_time by shifting over partition
        lf = lf.sort(primary_key, START).with_columns(
            # NOTE: Create nulls for the last entry of each partition
            # Will be dropped in the next step
            pl.col(START)
            .shift(-1)
            .over(primary_key)
            .alias(END)
        )
        # Enforce max gap treshold
        if recordtime_end_gap_treshold:
            if recordtime_end_gap_behavior == "drop":
                lf = lf.remove(
                    pl.col(START)
                    .dt.offset_by(recordtime_end_gap_treshold)
                    .lt(pl.col(END))
                )
            else:
                lf = lf.with_columns(
                    pl.min_horizontal(
                        pl.col(START).dt.offset_by(recordtime_end_gap_treshold),
                        pl.col(END),
                    ).alias(END)
                )

    # Drop nulls and calculate duration
    lf = lf.drop_nulls(
        [primary_key, START, END, VALUE]
        + ([TMIN] if tmin is not None and t_missing_behavior == "drop" else [])
        + ([TMAX] if tmax is not None and t_missing_behavior == "drop" else [])
        + ([T0] if t0 is not None and t_missing_behavior == "drop" else [])
    ).with_columns(pl.col(END).sub(pl.col(START)).alias(DURATION))

    if duration_empty_impute_duration:
        lf = lf.with_columns(
            pl.when(pl.col(DURATION).eq(0))
            .then(pl.col(END).dt.offset_by(duration_empty_impute_duration))
            .otherwise(pl.col(END))
            .alias(END)
        ).with_columns(
            # Recompute the duration from the imputed end time. Without this
            # the imputed rows would still carry a zero duration and be removed
            # by the `duration > 0` filter below.
            pl.col(END).sub(pl.col(START)).alias(DURATION)
        )

    # Filter out empty intervals and intervals not intersecting [tmin, tmax]
    lf_clipped = (
        lf.filter(
            pl.col(DURATION).gt(0)
            & pl.col(START).lt(pl.col(TMAX)).or_(pl.col(TMAX).is_null())
            & pl.col(END).gt(pl.col(TMIN)).or_(pl.col(TMIN).is_null())
        )
        .with_columns(
            pl.col(START).clip(TMIN, TMAX).alias(START_CLIPPED_TMIN_TMAX),
            pl.col(END).clip(TMIN, TMAX).alias(END_CLIPPED_TMIN_TMAX),
        )
        .drop(TMIN, TMAX)
        .with_columns(
            pl.col(END_CLIPPED_TMIN_TMAX)
            .sub(pl.col(START_CLIPPED_TMIN_TMAX))
            .alias(DURATION_CLIPPED_TMIN_TMAX)
        )
    )

    # Calculate corrected values for intervals clipped to [tmin, tmax]
    # Skip this step when tmin and tmax are both unspecified for performance
    # It is practically a no-op in that case
    corrected_value_expr: pl.Expr
    if tmin is not None or tmax is not None:
        corrected_value_expr = (
            pl.col(VALUE)
            .cast(pl.Float64)
            .mul(pl.col(DURATION_CLIPPED_TMIN_TMAX))
            .truediv(pl.col(DURATION))
        )
    else:
        corrected_value_expr = pl.col(VALUE).cast(pl.Float64)

    lf_clipped = (
        lf_clipped
        # Shouldn't be possible after the previous filter step
        .filter(pl.col(DURATION_CLIPPED_TMIN_TMAX).gt(0)).with_columns(
            corrected_value_expr.alias(VALUE_CLIPPED_TMIN_TMAX)
        )
    )

    # NOTE: Empty string is also not valid
    correction_expr: pl.Expr
    datetime_ranges_expr: pl.Expr
    intersection_condition_expr: pl.Expr
    lf_distributed: pl.LazyFrame
    if granularity:
        # Use t0 as relative anchor for the buckets
        if t0 is not None:
            # Apply correction offset to align bins to a grid starting from t0
            correction_expr = pl.col(T0).sub(pl.col(T0).dt.truncate(granularity))
            datetime_ranges_expr = pl.datetime_ranges(
                pl.col(START_CLIPPED_TMIN_TMAX)
                .sub(correction_expr)
                .dt.truncate(granularity),
                pl.col(END_CLIPPED_TMIN_TMAX).sub(correction_expr),
                granularity,
            )
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    datetime_ranges_expr.alias(START_SPLIT)
                ).explode(START_SPLIT)
                # Revert correction offset
                .with_columns(
                    pl.col(START_SPLIT).add(correction_expr),
                )
            )
        # Snap buckets to previous grid point
        else:
            datetime_ranges_expr = pl.datetime_ranges(
                pl.col(START_CLIPPED_TMIN_TMAX).dt.truncate(granularity),
                END_CLIPPED_TMIN_TMAX,
                granularity,
            )
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    datetime_ranges_expr.alias(START_SPLIT)
                ).explode(START_SPLIT)
            )
        # NOTE(review): `pl.datetime_ranges` is called with its default
        # `closed` setting; an interval ending exactly on a grid point may get
        # an extra zero-length bucket — confirm whether that is intended.
        lf_distributed = (
            lf_distributed
            # Calculate end_granularity
            .with_columns(
                pl.col(START_SPLIT).dt.offset_by(granularity).alias(END_SPLIT),
            )
        )
    else:
        # Split intervals containing t0 into two intervals
        if t0 is not None:
            # Find intervals containing t0
            intersection_condition_expr = (
                pl.col(START_CLIPPED_TMIN_TMAX)
                .lt(pl.col(T0))
                .and_(pl.col(END_CLIPPED_TMIN_TMAX).gt(pl.col(T0)))
            )
            # Split intervals containing t0
            lf_distributed = lf_clipped.with_columns(
                pl.when(intersection_condition_expr)
                .then(
                    pl.concat_list(
                        pl.col(START_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)),
                        pl.col(T0).cast(pl.List(pl.Datetime)),
                    )
                )
                .otherwise(pl.col(START_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)))
                .alias(START_SPLIT),
                pl.when(intersection_condition_expr)
                .then(
                    pl.concat_list(
                        pl.col(T0).cast(pl.List(pl.Datetime)),
                        pl.col(END_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)),
                    )
                )
                .otherwise(pl.col(END_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)))
                .alias(END_SPLIT),
            ).explode(START_SPLIT, END_SPLIT)
        # Do nothing: Split times are equal to previous times and not exploded
        else:
            # Use clipped times as bucket times
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    pl.col(START_CLIPPED_TMIN_TMAX).alias(START_SPLIT),
                    pl.col(END_CLIPPED_TMIN_TMAX).alias(END_SPLIT),
                )
            )

    # Calculate corrected values for the part of each interval that overlaps
    # its bucket. Skip this step when intervals were not split (no granularity
    # and no t0): the overlap then equals the clipped interval.
    if granularity or t0 is not None:
        corrected_value_expr = (
            pl.col(VALUE_CLIPPED_TMIN_TMAX)  # Already cast to Float64
            .mul(pl.col(DURATION_OVERLAP))
            .truediv(pl.col(DURATION_CLIPPED_TMIN_TMAX))
        )
    else:
        corrected_value_expr = pl.col(VALUE_CLIPPED_TMIN_TMAX)

    lf_distributed = (
        lf_distributed.with_columns(
            pl.max_horizontal(START_CLIPPED_TMIN_TMAX, START_SPLIT).alias(
                START_OVERLAP
            ),
            pl.min_horizontal(END_CLIPPED_TMIN_TMAX, END_SPLIT).alias(END_OVERLAP),
        )
        .with_columns(
            pl.col(END_OVERLAP).sub(pl.col(START_OVERLAP)).alias(DURATION_OVERLAP)
        )
        .with_columns(corrected_value_expr.alias(VALUE_OVERLAP))
        .select(
            primary_key,
            START_SPLIT,
            END_SPLIT,
            T0,
            START_OVERLAP,
            END_OVERLAP,
            VALUE_OVERLAP,
        )
    )

    # Importantly this data is not yet aggregated on the intervals...
    # There can be multiple entries per interval

    # Aggregate interval
    def agg_interval(
        df: pl.LazyFrame, more_group_by_cols: Iterable[pl.Expr | str] = ()
    ):
        return (
            df.group_by(primary_key, *more_group_by_cols)
            .agg(
                pl.min(START_OVERLAP).alias(AGG_START_OVERLAP),
                pl.max(END_OVERLAP).alias(AGG_END_OVERLAP),
                (
                    pl.col(VALUE_OVERLAP).pipe(same_bin_aggregation).alias(AGG_VALUE)
                    if same_bin_aggregation
                    else pl.col(VALUE_OVERLAP).sum().alias(AGG_VALUE)
                ),
            )
            # Recalculate overlap_duration_s with new aggregated timestamps
            .with_columns(
                pl.col(AGG_END_OVERLAP)
                .sub(pl.col(AGG_START_OVERLAP))
                .dt.total_seconds()
                .alias(AGG_DURATION_OVERLAP_S)
            )
            .sort(primary_key, AGG_START_OVERLAP)
        )

    # `group_cols` are the keys passed to group_by (they may be expressions).
    # `group_col_names` are the resulting column names, used to select the keys
    # from the aggregated frame, where the expression inputs no longer exist.
    group_cols: list[str | pl.Expr] = []
    group_col_names: list[str] = []
    if granularity:
        group_col_names = [START_SPLIT, END_SPLIT] + ([T0] if t0 is not None else [])
        group_cols = list(group_col_names)
    elif t0 is not None:
        group_cols = [
            pl.when(pl.col(START_SPLIT).lt(pl.col(T0)))
            .then(pl.lit(-1))
            .otherwise(pl.lit(1))
            .alias(BEFORE_OR_AFTER_T0_SIGNED),
        ]
        group_col_names = [BEFORE_OR_AFTER_T0_SIGNED]

    agg_df = lf_distributed.pipe(agg_interval, group_cols)

    extra_attributes: list[pl.Expr | str] = [
        AGG_START_OVERLAP,
        AGG_END_OVERLAP,
        AGG_DURATION_OVERLAP_S,
    ]
    if granularity and t0 is not None:
        # Signed distance of the bin start from t0, in units of the bin width.
        extra_attributes += [
            pl.col(START_SPLIT)
            .sub(pl.col(T0))
            .truediv(pl.col(END_SPLIT).sub(pl.col(START_SPLIT)))
            .alias(T0_GRANULARITY_DISTANCE)
        ]
    elif t0 is not None:
        extra_attributes += [BEFORE_OR_AFTER_T0_SIGNED]

    agg_df_packed = agg_df.with_columns(
        pl.struct(extra_attributes).alias(ATTRIBUTES)
    ).select(primary_key, *group_col_names, AGG_VALUE, ATTRIBUTES)

    return agg_df_packed if isinstance(data, pl.LazyFrame) else agg_df_packed.collect()
# Typing overloads: the cleaned frame has the same backend as the input frame.
@overload
def clean_timeseries(
    data: PolarsDataframe,
    max_roc: float,
    time_col: str = ...,
    value_col: str = ...,
    identifier_col: str = ...,
) -> PolarsDataframe: ...


@overload
def clean_timeseries(
    data: PandasDataframe,
    max_roc: float,
    time_col: str = ...,
    value_col: str = ...,
    identifier_col: str = ...,
) -> PandasDataframe: ...
[docs]
def clean_timeseries(
    data: PandasDataframe | PolarsDataframe,
    max_roc: float,
    time_col: str = "recordtime",
    value_col: str = "value",
    identifier_col: str = "icu_stay_id",
) -> PandasDataframe | PolarsDataframe:
    """
    Clean timeseries by removing values that create impossible rates of change.

    Each `identifier_col` group is sorted by time and scanned once. A value is
    dropped when its rate of change relative to the last kept value is NaN or
    reaches `max_roc`. The first value of a group is always kept. A
    "current_roc" column with the rate of change between consecutive kept
    values is added.

    Args:
        data (pd.DataFrame | pl.DataFrame): The data to clean.
        max_roc (float): The maximum rate of change (exclusive), per hour.
        time_col (str): The time column. Time differences are divided by 3600,
            i.e. presumably numeric seconds are expected — TODO confirm.
        value_col (str): The value column.
        identifier_col (str): The identifier column.

    Returns:
        result: The cleaned data with the same type as `data`.
    """
    # Convert to Polars if necessary
    return_pandas = isinstance(data, pd.DataFrame)
    if return_pandas:
        data = pl.from_pandas(data)

    def clean_group(group):
        if len(group) <= 1:
            return group.with_columns(pl.lit(None).alias("current_roc"))

        group = group.sort(time_col)
        times = group.select(time_col).to_numpy().flatten()
        values = group.select(value_col).to_numpy().flatten()

        keep_mask = np.ones(len(group), dtype=bool)
        # The first row is never dropped, so the last kept index can be tracked
        # directly instead of rescanning the mask on every iteration.
        last_kept_idx = 0
        for i in range(1, len(group)):
            time_diff = (
                times[i] - times[last_kept_idx]
            ) / 3600  # Convert seconds to hours
            value_diff = values[i] - values[last_kept_idx]
            current_roc = abs(value_diff / time_diff) if time_diff > 0 else float("inf")
            if np.isnan(current_roc) or current_roc >= max_roc:
                keep_mask[i] = False
            else:
                last_kept_idx = i

        filtered_group = group.filter(pl.lit(keep_mask))

        # Calculate RoC for kept values
        if len(filtered_group) <= 1:
            return filtered_group.with_columns(pl.lit(None).alias("current_roc"))
        return filtered_group.with_columns(
            (
                # Use the configured value column (not a hard-coded "value").
                (pl.col(value_col) - pl.col(value_col).shift(1))
                / ((pl.col(time_col) - pl.col(time_col).shift(1)) / 3600)
            ).alias("current_roc")
        )

    partitions = data.partition_by(identifier_col)
    cleaned_groups = [
        clean_group(group) for group in tqdm(partitions, desc="Cleaning timeseries")
    ]
    res = (
        pl.concat(cleaned_groups, how="vertical_relaxed")
        if cleaned_groups
        else data.clear()
    )
    return res.to_pandas() if return_pandas else res
[docs]
def harmonize_str_list_cols(dfs: Iterable[pl.DataFrame]) -> list[pl.DataFrame]:
    """Harmonize columns that are String in some frames and List(String) in others.

    Every column that occurs with both dtypes is cast to `List(String)` in all
    frames that contain it. Frames without the column are left unchanged.

    Args:
        dfs (Iterable[pl.DataFrame]): The dataframes to harmonize.

    Returns:
        list[pl.DataFrame]: The dataframes in input order with harmonized columns.

    Raises:
        ValueError: If none of the dataframes has any column.
    """
    # Materialise once: the frames are traversed several times, which would
    # silently exhaust a generator.
    frames = list(dfs)

    cols = set().union(*(df.columns for df in frames))
    if not cols:
        raise ValueError("All dataframes are empty.")

    list_of_str = pl.List(pl.String)
    for col in cols:
        dtypes = [df.schema[col] for df in frames if col in df.columns]
        has_str = any(dtype == pl.String for dtype in dtypes)
        has_list = any(dtype == list_of_str for dtype in dtypes)
        if has_str and has_list:
            # Update the running list so earlier harmonised columns are kept.
            frames = [
                (
                    df.with_columns(
                        pl.col(col).cast(list_of_str, strict=False).alias(col)
                    )
                    if col in df.columns
                    else df
                )
                for df in frames
            ]
    return frames
[docs]
def guess_variable_source(variable: Variable) -> str | None:
    """Guess the source of a variable from the module it is defined in.

    Returns the first module path component below the `sources` package, or
    `None` if the module cannot be determined or lies outside that package.
    """
    module = inspect.getmodule(variable)
    if module is None:
        return None

    sources_prefix = f"{sources.__name__}."
    module_name = module.__name__
    if module_name.startswith(sources_prefix):
        below_sources = module_name.removeprefix(sources_prefix)
        return below_sources.split(".", maxsplit=1)[0]
    return None
[docs]
def find_unknown_keys(
    user_cfg: Mapping[str, Any],
    schema: Mapping[str, Any],
    prefix: str = "",
    warnings: list[str] | None = None,
) -> list[str]:
    """Collect a message for every key of `user_cfg` that is missing from `schema`.

    Nested mappings are checked recursively when both the config value and the
    schema value are mappings. Key paths are reported in dotted notation.

    Args:
        user_cfg: The user supplied configuration.
        schema: The mapping whose keys define what is allowed.
        prefix: Dotted path of the parent keys. Defaults to "".
        warnings: List to append the messages to. A new list is created if `None`.

    Returns:
        list[str]: The list of messages (the same object as `warnings` if given).
    """
    collected = [] if warnings is None else warnings
    for name, entry in user_cfg.items():
        dotted = name if not prefix else f"{prefix}.{name}"
        if name in schema:
            expected = schema[name]
            if isinstance(entry, Mapping) and isinstance(expected, Mapping):
                find_unknown_keys(entry, expected, dotted, collected)
        else:
            collected.append(f"Unknown key in config: {dotted}")
    return collected
[docs]
def deep_merge(base: T, override: Mapping[str, Any]) -> T:
    """
    Merge `override` into `base` in place, recursing into nested mappings.

    Values from `override` win. Recursion only happens where `base` already
    holds a mutable mapping and `override` supplies a mapping for the same key;
    otherwise the value is replaced. Works on normal dictionaries, compatible
    with TypedDict runtime dicts.

    Returns:
        The mutated `base`.
    """
    for name, incoming in override.items():
        mergeable = (
            isinstance(incoming, Mapping)
            and name in base
            and isinstance(base[name], MutableMapping)
        )
        if not mergeable:
            base[name] = incoming
            continue
        deep_merge(base[name], incoming)
    return base
[docs]
def load_source_defaults(
    source_kwargs: Mapping[str, Any],
    defaults: T,
    migrations: Mapping[str, Any] | None = None,
) -> T:
    """Build a full source configuration from user kwargs and defaults.

    Args:
        source_kwargs: The user supplied source configuration. Not modified.
        defaults: The default configuration. Not modified (a deep copy is merged).
        migrations: Optional mapping of a deprecated top-level key to a
            `(new_parent, new_key)` pair. The value of the old key is moved to
            `config[new_parent][new_key]`.

    Returns:
        A deep copy of `defaults` with the (migrated) user configuration merged
        on top.
    """
    # 1. Load source config (shallow copy of the top level)
    config = dict(source_kwargs)

    # 2. Apply migrations for old keys
    if migrations:
        for old_key, (new_parent, new_key) in migrations.items():
            if old_key not in config:
                continue
            # Copy the parent mapping before writing to it: the top-level copy
            # above still shares nested mappings with the caller's kwargs.
            parent = dict(config.get(new_parent, {}))
            parent[new_key] = config.pop(old_key)
            config[new_parent] = parent

    # 3. Log warnings
    unknown_key_warnings = find_unknown_keys(config, defaults)
    log_collection(
        logger=logger, collection=unknown_key_warnings, level=logging.WARNING
    )

    # 4. Deep merge overrides on defaults
    result = deep_merge(deepcopy(defaults), config)

    # 5. `result` now has the shape of `defaults` — tell the type checker
    return cast(T, result)