# Source code for corr_vars.utils.helpers

from __future__ import annotations

from copy import deepcopy
import inspect
import warnings

import logging
import numpy as np
import polars as pl
import pandas as pd
from pandas.core.groupby import DataFrameGroupBy
from tqdm import tqdm

from corr_vars import logger
import corr_vars.sources as sources
from corr_vars.utils.logging import log_collection

from collections.abc import Callable, Iterable, Mapping, MutableMapping
from corr_vars.definitions import (
    TimeBoundColumn,
    IntoExprColumn,
    PandasDataframe,
    PolarsDataframe,
    PolarsFrame,
)
from typing import Literal, Any, overload, TYPE_CHECKING, TypeVar, cast

if TYPE_CHECKING:
    from corr_vars.core.variable import Variable

# Mutable-mapping type variable: lets deep_merge / load_source_defaults return
# the same mapping type they were given (e.g. a TypedDict-shaped dict).
T = TypeVar("T", bound=MutableMapping[str, Any])


# Unused.
def filter_by_condition(
    df: pl.DataFrame,
    expression: (
        IntoExprColumn
        | Iterable[IntoExprColumn]
        | bool
        | list[bool]
        | np.ndarray[Any, Any]
    ),
    description: str = "filter_by_condition call",
    verbose: bool = True,
    mode: Literal["drop", "keep"] = "drop",
) -> pl.DataFrame:
    """
    Drop or keep rows of a Polars DataFrame based on a boolean condition.

    Parameters:
        df (pl.DataFrame): The input DataFrame.
        expression: Expression(s) that evaluate to boolean column(s).
        description (str): Description of the condition, used in the log message.
        verbose (bool): If True, log how many rows matched the condition.
        mode (str): "drop" removes matching rows, "keep" retains only matching
            rows.

    Returns:
        pl.DataFrame: The filtered DataFrame.

    Raises:
        ValueError: If the condition does not evaluate to booleans, or if
            `mode` is neither "drop" nor "keep".
    """
    mask = df.select(expression)
    if any(dtype != pl.Boolean for dtype in mask.schema.values()):
        raise ValueError(
            "Conditional expression must return a boolean Series/Dataframe."
        )
    # Nulls in the mask are handled differently by `remove` and `filter`,
    # so surface them to the caller.
    if mask.null_count().select(pl.any_horizontal(pl.all().gt(0))).item():
        warnings.warn(
            "Condition expression should return a boolean Series/Dataframe without NA values. Consider the use of `fillna`.",
            UserWarning,
        )

    if mode == "drop":
        df2 = df.remove(expression)
    elif mode == "keep":
        df2 = df.filter(expression)
    else:
        raise ValueError(f"Mode must be 'drop' or 'keep'. Got {mode}")

    if verbose:
        true_count = mask.sum().sum_horizontal().item()
        # Guard against ZeroDivisionError for an empty input frame.
        share = true_count / len(df) if len(df) else 0.0
        logger.info(
            f"{mode.upper()}: {true_count} rows ({share:.2%}) due to {description}"
        )
    return df2
# Unused. Written for pandas.
def df_find_closest(
    group: pd.DataFrame,
    to_col: str,
    tdelta: str = "0",
    pm_before: str = "52w",
    pm_after: str = "52w",
) -> pd.Series[Any] | pd.DataFrame:
    """Find the record whose ``recordtime`` is closest to the target time.

    The target time is computed row-wise as ``group[to_col] + tdelta``. Only rows
    with ``recordtime`` in ``[target - pm_before, target + pm_after]`` are
    considered. The input frame is not modified.

    Args:
        group (pd.DataFrame): The input dataframe; must contain ``recordtime``.
        to_col (str): The column containing the target time.
        tdelta (str): Offset added to the target time. Defaults to "0".
        pm_before (str): Allowed range before the target time. Defaults to "52w".
        pm_after (str): Allowed range after the target time. Defaults to "52w".

    Returns:
        pd.Series[Any] | pd.DataFrame: The closest record (a DataFrame only if the
        winning index label is duplicated), or an all-NaN Series indexed by the
        group's columns if no record is within range.
    """
    offset = pd.to_timedelta(tdelta)
    before = pd.to_timedelta(pm_before)
    after = pd.to_timedelta(pm_after)

    target_time = group[to_col] + offset
    # Kept as a local Series instead of a new column so the caller's frame is
    # not mutated and no helper column leaks into the result.
    timediff = (group["recordtime"] - target_time).abs()
    in_range = (group["recordtime"] >= (target_time - before)) & (
        group["recordtime"] <= (target_time + after)
    )

    if not in_range.any():
        return pd.Series([np.nan] * len(group.columns), index=group.columns)

    within_range = group[in_range]
    return within_range.loc[timediff[in_range].idxmin()]
# Unused. Written for pandas.
[docs] def aggregate_column( group: DataFrameGroupBy, agg_func: str, column: str, params: list[str] = [] ) -> pd.Series[Any]: match agg_func: case "median": return group[column].median() case "mean": return group[column].mean() case "perc": return group[column].quantile(float(params[0]) / 100) case "count": return group.size() case _: raise ValueError(f"Unknown aggregation function: {agg_func}")
@overload
def get_cb(
    df: PolarsDataframe,
    join_key: str = ...,
    primary_key: str = ...,
    tmin_col: str = ...,
    tmax_col: str = ...,
    ttarget_col: str | None = ...,
) -> PolarsDataframe: ...


@overload
def get_cb(
    df: PandasDataframe,
    join_key: str = ...,
    primary_key: str = ...,
    tmin_col: str = ...,
    tmax_col: str = ...,
    ttarget_col: str | None = ...,
) -> PandasDataframe: ...


def get_cb(
    df: PandasDataframe | PolarsDataframe,
    join_key: str = "case_id",
    primary_key: str = "case_id",
    tmin_col: str = "tmin",
    tmax_col: str = "tmax",
    ttarget_col: str | None = None,
) -> PandasDataframe | PolarsDataframe:
    """Build the case-bounds frame used to time-filter a native dynamic variable.

    Rows with a missing `tmin_col` or `tmax_col` are discarded. The remaining
    rows are reduced to the join key, the two bounds (renamed to ``case_tmin`` /
    ``case_tmax``), the optional target-time column and, when it differs from
    the join key, the primary key.

    Args:
        df (pd.DataFrame | pl.DataFrame): The cohort dataframe.
        join_key (str): The join key column. Defaults to "case_id".
        primary_key (str): The primary key column. Defaults to "case_id".
        tmin_col (str): The lower-bound column. Defaults to "tmin".
        tmax_col (str): The upper-bound column. Defaults to "tmax".
        ttarget_col (str | None): The target-time column (for !closest).
            Defaults to None.

    Returns:
        result: The case bounds dataframe, of the same type as `df`.
    """
    wanted = [join_key, tmin_col, tmax_col]
    if ttarget_col:
        wanted.append(ttarget_col)
    if primary_key != join_key:
        wanted.append(primary_key)
    renames = {tmin_col: "case_tmin", tmax_col: "case_tmax"}

    if not isinstance(df, pl.DataFrame):
        # pandas path
        bounded = df.dropna(subset=[tmin_col, tmax_col])
        return bounded[wanted].rename(columns=renames)

    # polars path: pre-existing case_tmin/case_tmax columns are removed before
    # selecting, so they can never collide with the renamed bounds.
    bounded = df.filter(
        pl.col(tmin_col).is_not_null() & pl.col(tmax_col).is_not_null()
    )
    bounded = bounded.drop(["case_tmin", "case_tmax"], strict=False)
    return bounded.select(wanted).rename(renames)
def _get_time_series(obs: pd.DataFrame, col_name: str, tdelta: str = "0") -> pd.Series: """Returns a series with the time since col_name (in datetime format) plus tdelta (in pd.Timedelta format). Args: obs (pd.DataFrame): The observation dataframe. col_name (str): The column name. tdelta (str): The time delta. Defaults to "0". Returns: pd.Series: A series with the time column + tdelta. """ tdelta = pd.to_timedelta(tdelta) return obs[col_name] + tdelta @overload def pl_parse_time_args( obs: PolarsFrame, tmin: TimeBoundColumn | None = ..., tmax: TimeBoundColumn | None = ..., ) -> PolarsFrame: ... @overload def pl_parse_time_args( obs: PandasDataframe, tmin: TimeBoundColumn | None = ..., tmax: TimeBoundColumn | None = ..., ) -> PandasDataframe: ...
def pl_parse_time_args(
    obs: PolarsFrame | PandasDataframe,
    tmin: TimeBoundColumn | None = None,
    tmax: TimeBoundColumn | None = None,
) -> PolarsFrame | PandasDataframe:
    """Attach ``tmin`` / ``tmax`` columns derived from the given time bounds.

    Each bound is either a column name or a ``(column, tdelta)`` tuple. In the
    tuple form the column is shifted by ``pd.to_timedelta(tdelta)``. Before the
    columns are attached, rows whose bound column lies in year 9999 are dropped.
    Rows where that column is missing are kept, identically in the Polars and
    pandas paths.

    Args:
        obs (pd.DataFrame | pl.DataFrame | pl.LazyFrame): The observation dataframe.
        tmin (TimeBoundColumn | None): The tmin argument. Defaults to None.
        tmax (TimeBoundColumn | None): The tmax argument. Defaults to None.

    Returns:
        result: `obs` (same type) filtered as described, with ``tmin`` and ``tmax``
        columns added. An unspecified bound becomes an all-null (Polars) or NaN
        (pandas) column. A pandas input frame is not modified in place.
    """
    bounds = (tmin, tmax)

    if isinstance(obs, pl.DataFrame | pl.LazyFrame):
        for bound in bounds:
            if bound:
                col = bound[0] if isinstance(bound, tuple) else bound
                # `dt.year()` is null for null timestamps, and `filter` drops null
                # predicates; fill_null(True) keeps those rows like pandas does.
                obs = obs.filter((pl.col(col).dt.year() != 9999).fill_null(True))

        def _bound_expr(bound: TimeBoundColumn | None) -> pl.Expr:
            if isinstance(bound, tuple):
                col_name, tdelta = bound
                # Convert the offset string to a duration and add it to the column
                return pl.col(col_name) + pl.duration(
                    seconds=pd.to_timedelta(tdelta).total_seconds()
                )
            if bound:
                return pl.col(bound)
            return pl.lit(None)

        return obs.with_columns(
            [_bound_expr(tmin).alias("tmin"), _bound_expr(tmax).alias("tmax")]
        )

    # pandas path: NaT.year is NaN and NaN != 9999, so missing values are kept.
    for bound in bounds:
        if bound:
            col = bound[0] if isinstance(bound, tuple) else bound
            obs = obs[obs[col].dt.year != 9999]

    def _bound_series(
        frame: pd.DataFrame, bound: TimeBoundColumn | None
    ) -> pd.Series | float:
        if isinstance(bound, tuple):
            col_name, tdelta = bound
            return _get_time_series(frame, col_name, tdelta=tdelta)
        if bound:
            return _get_time_series(frame, col_name=bound)
        return np.nan

    # `assign` returns a new frame: avoids mutating the caller's DataFrame and
    # SettingWithCopy issues on the filtered slice.
    return obs.assign(tmin=_bound_series(obs, tmin), tmax=_bound_series(obs, tmax))
# Unused. Written for pandas.
def extract_df_data(
    df: pd.DataFrame,
    col_dict: dict | None = None,
    filter_dict: dict[str, list[str]] | None = None,
    exact_match: bool = False,
    remove_prefix: bool = False,
    drop: bool = False,
) -> pd.DataFrame:
    """
    Extract (rename, subset and row-filter) data from a DataFrame.

    Parameters:
        df (pandas.DataFrame): The DataFrame to operate on.
        col_dict (dict, optional): Mapping of column names to new names.
            Defaults to None.
        filter_dict (dict[str, list[str]], optional): Keys are column names and
            values are lists of values to keep rows by. For text columns with
            exact_match=False the values are OR-ed together as a regex pattern.
            An empty list keeps no rows. Defaults to None.
        exact_match (bool, optional): If True, always match values exactly
            (``isin``). Defaults to False.
        remove_prefix (bool, optional): Currently ignored by this function.
            Defaults to False.
        drop (bool, optional): If True (and col_dict is given), keep only the
            renamed columns from col_dict.

    Returns:
        pandas.DataFrame: The extracted data.
    """
    # rename columns if applicable
    if col_dict:
        df = df.rename(columns=col_dict)
        if drop:
            df = df.filter(items=list(col_dict.values()))

    # filter rows for given values
    if filter_dict:
        for column, values in filter_dict.items():
            series = df[column]
            # Regex matching applies to object columns and pandas string dtypes
            # (the default string dtype in newer pandas).
            is_text = series.dtype == "object" or isinstance(
                series.dtype, pd.StringDtype
            )
            # An empty value list would build the regex "" which matches every
            # row; route it through isin so it matches nothing, like exact_match.
            if exact_match or not is_text or len(values) == 0:
                df = df[series.isin(values)]
            else:
                pattern = "|".join(values)
                df = df[series.str.contains(pattern, na=False, regex=True)]
    return df
def merge_consecutive(
    data: pd.DataFrame,
    primary_key: str,  # Deprecated!
    recordtime: str = "recordtime",
    recordtime_end: str = "recordtime_end",
    time_threshold: pd.Timedelta = pd.Timedelta(minutes=30),
) -> pd.DataFrame:
    """
    Combine consecutive sessions of the same case into a single session.

    Within each ``case_id``, a session is merged into the previous one when it
    starts less than `time_threshold` after the previous session's end.

    Args:
        data (pd.DataFrame): The data to merge; needs ``case_id``, ``value``,
            `recordtime` and `recordtime_end` columns.
        primary_key (str): Deprecated and ignored; ``case_id`` is always used.
        recordtime (str): The session start column. Defaults to "recordtime".
        recordtime_end (str): The session end column. Defaults to "recordtime_end".
        time_threshold (pd.Timedelta): Maximum gap for merging. Defaults to 30 minutes.

    Returns:
        pd.DataFrame: One row per merged session with columns ``case_id``,
        ``merge_group`` (1-based per case), ``value`` (first), `recordtime`
        (first) and `recordtime_end` (last).
    """
    # Sort by case and start time so "previous" means the preceding session.
    data = data.sort_values(by=["case_id", recordtime]).reset_index(drop=True)

    # End of the previous session within the same case (NaT for the first one).
    previous_end = data.groupby("case_id")[recordtime_end].shift(1)
    gap = data[recordtime] - previous_end
    # NaT gaps compare False, so a case's first session always starts a group.
    starts_new_session = ~(gap < time_threshold)
    merge_group = starts_new_session.astype(int).groupby(data["case_id"]).cumsum()

    return (
        data.assign(merge_group=merge_group)
        .groupby(by=["case_id", "merge_group"])
        .agg({"value": "first", recordtime: "first", recordtime_end: "last"})
        .reset_index()
    )
def interval_bucket_agg(
    data: PolarsFrame,
    primary_key: str = "case_id",
    recordtime: str = "recordtime",
    recordtime_end: str | None = None,
    recordtime_end_gap_treshold: str | None = None,
    recordtime_end_gap_behavior: Literal["impute", "drop"] = "drop",
    duration_empty_impute_duration: str | None = None,
    value: str = "value",
    tmin: str | None = None,
    tmax: str | None = None,
    t0: str | None = None,
    t_missing_behavior: Literal["keep", "drop"] = "drop",
    granularity: str | None = "1h",
    same_bin_aggregation: Callable[[pl.Expr], pl.Expr] | None = None,
) -> PolarsFrame:
    """
    Bins intervals clipped between tmin and tmax into bins aligned to t0 and aggregates on these bins.

    Args:
        data (pl.DataFrame | pl.LazyFrame): The data to transform.
        primary_key (str): The primary key column. Defaults to "case_id".
        recordtime (str): The recordtime column. Defaults to "recordtime".
        recordtime_end (str | None): The recordtime_end column. Defaults to `None`.
            Will impute recordtime_end from the next recordtime if set to `None`.
        recordtime_end_gap_treshold (str | None): Maximum allowed gap for recordtime_end.
            Only applies if recordtime_end is set to `None`.
        recordtime_end_gap_behavior (str | None): Determines whether to impute or drop rows
            when the `recordtime_end_gap_treshold` is exceeded.
        duration_empty_impute_duration (str | None): The duration to be imputed by offsetting
            recordtime_end for zero-length intervals. Such intervals are dropped if set to the
            default of `None` or ≤0s.
        value (str): The value column. Defaults to "value".
        tmin (str | None): The tmin column. Defaults to `None`.
        tmax (str | None): The tmax column. Defaults to `None`.
        t0 (str | None): The t0 column. Bins will align to this datetime col if specified.
            Defaults to `None`.
        t_missing_behavior (str | None): Determines whether to drop or keep rows when `t0`,
            `tmin` or `tmax` columns are specified and values are missing in `data`.
            Defaults to "drop".
        granularity (str | None): The size of the bucket intervals. Defaults to 1 hour.
        same_bin_aggregation: (Callable[[pl.Expr], pl.Expr] | None): The aggregation to
            perform on `value` for each bin. Uses `pl.sum()` if set to `None`.

    .. Note::
        The `granularity`, `recordtime_end_gap_treshold` and `duration_empty_impute_duration`
        arguments are created with the following string language:

        - 1ns (1 nanosecond) # nanoseconds are not supported by our DataFrames.
        - 1us (1 microsecond)
        - 1ms (1 millisecond)
        - 1s (1 second)
        - 1m (1 minute)
        - 1h (1 hour)
        - 1d (1 calendar day)
        - 1w (1 calendar week)
        - 1mo (1 calendar month)
        - 1q (1 calendar quarter)
        - 1y (1 calendar year)

        Or combine them: "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds

        By "calendar day", we mean the corresponding time on the next day (which may not be
        24 hours, due to daylight savings). Similarly for "calendar week", "calendar month",
        "calendar quarter", and "calendar year". If you do not wish to use "calendar day" or
        any other "calendar" durations, specify the duration in hours or smaller units instead.

    Returns:
        clip_bin_data (pl.DataFrame | pl.LazyFrame): The clipped and binned data. Same type as `data`.
    """
    lazy_input = data if isinstance(data, pl.LazyFrame) else data.lazy()

    START = "start_time"
    END = "end_time"
    DURATION = "duration_time"
    VALUE = "value"
    TMIN = "tmin"
    TMAX = "tmax"
    T0 = "t0"
    START_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{START}"
    END_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{END}"
    DURATION_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{DURATION}"
    VALUE_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{VALUE}"
    START_SPLIT = f"split_{START}"
    END_SPLIT = f"split_{END}"
    START_OVERLAP = f"overlap_{START}"
    END_OVERLAP = f"overlap_{END}"
    DURATION_OVERLAP = f"overlap_{DURATION}"
    VALUE_OVERLAP = f"overlap_{VALUE}"
    BEFORE_OR_AFTER_T0_SIGNED = "t0_side_signed"
    T0_GRANULARITY_DISTANCE = "t0_granularity_distance"
    AGG_START_OVERLAP = f"first_{START_OVERLAP}"
    AGG_END_OVERLAP = f"last_{END_OVERLAP}"
    AGG_DURATION_OVERLAP_S = f"max_{DURATION_OVERLAP}_s"
    AGG_VALUE = f"agg_{VALUE}"
    ATTRIBUTES = "attributes"

    # Select required columns from input and rename them
    # Fill recordtime_end, tmin, tmax and t0 with None if unspecified
    lf = lazy_input.select(
        primary_key,
        pl.col(recordtime).alias(START),
        (
            pl.lit(None).alias(END)
            if recordtime_end is None
            else pl.col(recordtime_end).alias(END)
        ),
        pl.col(value).alias(VALUE),
        pl.lit(None).alias(TMIN) if tmin is None else pl.col(tmin).alias(TMIN),
        pl.lit(None).alias(TMAX) if tmax is None else pl.col(tmax).alias(TMAX),
        pl.lit(None).alias(T0) if t0 is None else pl.col(t0).alias(T0),
    )

    # impute recordtime_end if necessary
    if recordtime_end is None:
        # Sort and calculate end_time by shifting over partition
        lf = lf.sort(primary_key, START).with_columns(
            # NOTE: Create nulls for the last entry of each partition
            # Will be dropped in the next step
            pl.col(START)
            .shift(-1)
            .over(primary_key)
            .alias(END)
        )
        # Enforce max gap treshold
        if recordtime_end_gap_treshold:
            if recordtime_end_gap_behavior == "drop":
                lf = lf.remove(
                    pl.col(START)
                    .dt.offset_by(recordtime_end_gap_treshold)
                    .lt(pl.col(END))
                )
            else:
                lf = lf.with_columns(
                    pl.min_horizontal(
                        pl.col(START).dt.offset_by(recordtime_end_gap_treshold),
                        pl.col(END),
                    ).alias(END)
                )

    # Drop nulls and calculate duration
    lf = lf.drop_nulls(
        [primary_key, START, END, VALUE]
        + ([TMIN] if tmin is not None and t_missing_behavior == "drop" else [])
        + ([TMAX] if tmax is not None and t_missing_behavior == "drop" else [])
        + ([T0] if t0 is not None and t_missing_behavior == "drop" else [])
    ).with_columns(pl.col(END).sub(pl.col(START)).alias(DURATION))

    if duration_empty_impute_duration:
        lf = lf.with_columns(
            pl.when(pl.col(DURATION).eq(0))
            .then(pl.col(END).dt.offset_by(duration_empty_impute_duration))
            .otherwise(pl.col(END))
            .alias(END)
        ).with_columns(
            # Recompute the duration from the imputed END; otherwise the
            # `DURATION > 0` filter below would still drop these rows and the
            # imputation would have no effect.
            pl.col(END).sub(pl.col(START)).alias(DURATION)
        )

    # Filter out empty intervals and intervals not intersecting [tmin, tmax]
    lf_clipped = (
        lf.filter(
            pl.col(DURATION).gt(0)
            & pl.col(START).lt(pl.col(TMAX)).or_(pl.col(TMAX).is_null())
            & pl.col(END).gt(pl.col(TMIN)).or_(pl.col(TMIN).is_null())
        )
        .with_columns(
            pl.col(START).clip(TMIN, TMAX).alias(START_CLIPPED_TMIN_TMAX),
            pl.col(END).clip(TMIN, TMAX).alias(END_CLIPPED_TMIN_TMAX),
        )
        .drop(TMIN, TMAX)
        .with_columns(
            pl.col(END_CLIPPED_TMIN_TMAX)
            .sub(pl.col(START_CLIPPED_TMIN_TMAX))
            .alias(DURATION_CLIPPED_TMIN_TMAX)
        )
    )

    # Calculate corrected values for intervals clipped to [tmin, tmax]
    # Skip this step when tmin or tmax unspecified for performance
    # It is practically a no-op in that case
    corrected_value_expr: pl.Expr
    if tmin is not None or tmax is not None:
        corrected_value_expr = (
            pl.col(VALUE)
            .cast(pl.Float64)
            .mul(pl.col(DURATION_CLIPPED_TMIN_TMAX))
            .truediv(pl.col(DURATION))
        )
    else:
        corrected_value_expr = pl.col(VALUE).cast(pl.Float64)

    lf_clipped = (
        lf_clipped
        # Shouldn't be possible after the previous filter step
        .filter(pl.col(DURATION_CLIPPED_TMIN_TMAX).gt(0)).with_columns(
            corrected_value_expr.alias(VALUE_CLIPPED_TMIN_TMAX)
        )
    )

    # NOTE: Empty string is also not valid
    correction_expr: pl.Expr
    datetime_ranges_expr: pl.Expr
    intersection_condition_expr: pl.Expr
    lf_distributed: pl.LazyFrame
    if granularity:
        # Use t0 as relative anchor for the buckets
        if t0 is not None:
            # Apply correction offset to align bins to a grid starting from t0
            correction_expr = pl.col(T0).sub(pl.col(T0).dt.truncate(granularity))
            datetime_ranges_expr = pl.datetime_ranges(
                pl.col(START_CLIPPED_TMIN_TMAX)
                .sub(correction_expr)
                .dt.truncate(granularity),
                pl.col(END_CLIPPED_TMIN_TMAX).sub(correction_expr),
                granularity,
            )
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    datetime_ranges_expr.alias(START_SPLIT)
                ).explode(START_SPLIT)
                # Revert correction offset
                .with_columns(
                    pl.col(START_SPLIT).add(correction_expr),
                )
            )
        # Snap buckets to previous grid point
        else:
            datetime_ranges_expr = pl.datetime_ranges(
                pl.col(START_CLIPPED_TMIN_TMAX).dt.truncate(granularity),
                END_CLIPPED_TMIN_TMAX,
                granularity,
            )
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    datetime_ranges_expr.alias(START_SPLIT)
                ).explode(START_SPLIT)
            )
        lf_distributed = (
            lf_distributed
            # Calculate end_granularity
            .with_columns(
                pl.col(START_SPLIT).dt.offset_by(granularity).alias(END_SPLIT),
            )
        )
    else:
        # Split intervals containing t0 into two intervals
        if t0 is not None:
            # Find intervals containing t0
            intersection_condition_expr = (
                pl.col(START_CLIPPED_TMIN_TMAX)
                .lt(pl.col(T0))
                .and_(pl.col(END_CLIPPED_TMIN_TMAX).gt(pl.col(T0)))
            )
            # Split intervals containing t0
            lf_distributed = lf_clipped.with_columns(
                pl.when(intersection_condition_expr)
                .then(
                    pl.concat_list(
                        pl.col(START_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)),
                        pl.col(T0).cast(pl.List(pl.Datetime)),
                    )
                )
                .otherwise(pl.col(START_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)))
                .alias(START_SPLIT),
                pl.when(intersection_condition_expr)
                .then(
                    pl.concat_list(
                        pl.col(T0).cast(pl.List(pl.Datetime)),
                        pl.col(END_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)),
                    )
                )
                .otherwise(pl.col(END_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)))
                .alias(END_SPLIT),
            ).explode(START_SPLIT, END_SPLIT)
        # Do nothing: Split times are equal to previous times and not exploded
        else:
            # Use clipped times as bucket times
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    pl.col(START_CLIPPED_TMIN_TMAX).alias(START_SPLIT),
                    pl.col(END_CLIPPED_TMIN_TMAX).alias(END_SPLIT),
                )
            )

    # Distribute the clipped value onto each bucket proportionally to the overlap
    # Skip this step when neither granularity nor t0 is set: buckets then equal
    # the clipped intervals, so the correction is a no-op
    if granularity or t0 is not None:
        corrected_value_expr = (
            pl.col(VALUE_CLIPPED_TMIN_TMAX)  # Already cast to Float64
            .mul(pl.col(DURATION_OVERLAP))
            .truediv(pl.col(DURATION_CLIPPED_TMIN_TMAX))
        )
    else:
        corrected_value_expr = pl.col(VALUE_CLIPPED_TMIN_TMAX)

    lf_distributed = (
        lf_distributed.with_columns(
            pl.max_horizontal(START_CLIPPED_TMIN_TMAX, START_SPLIT).alias(
                START_OVERLAP
            ),
            pl.min_horizontal(END_CLIPPED_TMIN_TMAX, END_SPLIT).alias(END_OVERLAP),
        )
        .with_columns(
            pl.col(END_OVERLAP).sub(pl.col(START_OVERLAP)).alias(DURATION_OVERLAP)
        )
        # `pl.datetime_ranges` includes the end point (closed="both"), so an
        # interval ending exactly on a grid point yields an extra bucket with a
        # zero-length overlap. Drop those spurious buckets.
        .filter(pl.col(DURATION_OVERLAP).gt(0))
        .with_columns(corrected_value_expr.alias(VALUE_OVERLAP))
        .select(
            primary_key,
            START_SPLIT,
            END_SPLIT,
            T0,
            START_OVERLAP,
            END_OVERLAP,
            VALUE_OVERLAP,
        )
    )

    # Importantly this data is not yet aggregated on the intervals...
    # There can be multiple entries per interval

    # Aggregate interval
    def agg_interval(
        df: pl.LazyFrame, more_group_by_cols: Iterable[pl.Expr | str] = ()
    ):
        return (
            df.group_by(primary_key, *more_group_by_cols)
            .agg(
                pl.min(START_OVERLAP).alias(AGG_START_OVERLAP),
                pl.max(END_OVERLAP).alias(AGG_END_OVERLAP),
                (
                    pl.col(VALUE_OVERLAP).pipe(same_bin_aggregation).alias(AGG_VALUE)
                    if same_bin_aggregation
                    else pl.col(VALUE_OVERLAP).sum().alias(AGG_VALUE)
                ),
            )
            # Recalculate overlap_duration_s with new aggregated timestamps
            .with_columns(
                pl.col(AGG_END_OVERLAP)
                .sub(pl.col(AGG_START_OVERLAP))
                .dt.total_seconds()
                .alias(AGG_DURATION_OVERLAP_S)
            )
            .sort(primary_key, AGG_START_OVERLAP)
        )

    # `group_cols` are the group_by keys (possibly expressions); `group_col_names`
    # are the names those keys have in the aggregated output.
    group_cols: list[str | pl.Expr] = []
    group_col_names: list[str] = []
    if granularity:
        group_col_names = [START_SPLIT, END_SPLIT] + ([T0] if t0 is not None else [])
        group_cols += group_col_names
    elif t0 is not None:
        group_cols += [
            pl.when(pl.col(START_SPLIT).lt(pl.col(T0)))
            .then(pl.lit(-1))
            .otherwise(pl.lit(1))
            .alias(BEFORE_OR_AFTER_T0_SIGNED),
        ]
        group_col_names = [BEFORE_OR_AFTER_T0_SIGNED]

    agg_df = lf_distributed.pipe(agg_interval, group_cols)

    extra_attributes: list[pl.Expr | str] = [
        AGG_START_OVERLAP,
        AGG_END_OVERLAP,
        AGG_DURATION_OVERLAP_S,
    ]
    if granularity and t0 is not None:
        extra_attributes += [
            pl.col(START_SPLIT)
            .sub(pl.col(T0))
            .truediv(pl.col(END_SPLIT).sub(pl.col(START_SPLIT)))
            .alias(T0_GRANULARITY_DISTANCE)
        ]
    elif t0 is not None:
        extra_attributes += [BEFORE_OR_AFTER_T0_SIGNED]

    # Select the group keys by output name: re-evaluating the key expression
    # here would reference START_SPLIT, which no longer exists after group_by.
    agg_df_packed = agg_df.with_columns(
        pl.struct(extra_attributes).alias(ATTRIBUTES)
    ).select(primary_key, *group_col_names, AGG_VALUE, ATTRIBUTES)

    return agg_df_packed if isinstance(data, pl.LazyFrame) else agg_df_packed.collect()
@overload
def clean_timeseries(
    data: PolarsDataframe,
    max_roc: float,
    time_col: str = ...,
    value_col: str = ...,
    identifier_col: str = ...,
) -> PolarsDataframe: ...


@overload
def clean_timeseries(
    data: PandasDataframe,
    max_roc: float,
    time_col: str = ...,
    value_col: str = ...,
    identifier_col: str = ...,
) -> PandasDataframe: ...


def clean_timeseries(
    data: PandasDataframe | PolarsDataframe,
    max_roc: float,
    time_col: str = "recordtime",
    value_col: str = "value",
    identifier_col: str = "icu_stay_id",
) -> PandasDataframe | PolarsDataframe:
    """
    Clean timeseries by removing values that create impossible rates of change.

    Each `identifier_col` partition is sorted by `time_col`. A sample is dropped
    when its absolute rate of change against the most recently kept sample is
    NaN or >= `max_roc`, or when it does not lie strictly after that sample in
    time. A ``current_roc`` column with the rate against the previous kept sample
    is added.

    NOTE(review): times are differenced and divided by 3600 ("seconds to hours"),
    so `time_col` is presumably numeric seconds, not a datetime. Confirm with callers.

    Args:
        data (pd.DataFrame | pl.DataFrame): The data to clean.
        max_roc (float): The maximum allowed rate of change (per hour).
        time_col (str): The time column.
        value_col (str): The value column.
        identifier_col (str): The identifier column.

    Returns:
        result: The cleaned data, of the same type as `data`.
    """
    return_pandas = False
    # Convert to Polars if necessary
    if isinstance(data, pd.DataFrame):
        data = pl.from_pandas(data)
        return_pandas = True

    def clean_group(group):
        if len(group) <= 1:
            return group.with_columns(pl.lit(None).alias("current_roc"))

        group = group.sort(time_col)
        times = group.select(time_col).to_numpy().flatten()
        values = group.select(value_col).to_numpy().flatten()

        keep_mask = np.ones(len(group), dtype=bool)
        # Index of the most recently kept sample. The first one is always kept.
        # Tracking it directly avoids an O(n) search per step.
        last_kept_idx = 0
        for i in range(1, len(group)):
            time_diff = (
                times[i] - times[last_kept_idx]
            ) / 3600  # Convert seconds to hours
            value_diff = values[i] - values[last_kept_idx]
            current_roc = abs(value_diff / time_diff) if time_diff > 0 else float("inf")
            if np.isnan(current_roc) or current_roc >= max_roc:
                keep_mask[i] = False
            else:
                last_kept_idx = i

        filtered_group = group.filter(pl.lit(keep_mask))

        # Calculate RoC for kept values
        if len(filtered_group) <= 1:
            return filtered_group.with_columns(pl.lit(None).alias("current_roc"))
        return filtered_group.with_columns(
            (
                (pl.col(value_col) - pl.col(value_col).shift(1))
                / ((pl.col(time_col) - pl.col(time_col).shift(1)) / 3600)
            ).alias("current_roc")
        )

    result = data.partition_by(identifier_col)
    cleaned_groups = [
        clean_group(group) for group in tqdm(result, desc="Cleaning timeseries")
    ]
    res = (
        pl.concat(cleaned_groups, how="vertical_relaxed")
        if cleaned_groups
        else data.clear()
    )
    return res.to_pandas() if return_pandas else res
def harmonize_str_list_cols(dfs: Iterable[pl.DataFrame]) -> list[pl.DataFrame]:
    """Align columns that are String in some frames and List(String) in others.

    Every column that appears as ``pl.String`` in at least one frame and as
    ``pl.List(pl.String)`` in another is cast (non-strictly) to
    ``pl.List(pl.String)`` in every frame that has it. This prepares the frames
    for concatenation.

    Args:
        dfs (Iterable[pl.DataFrame]): The frames to harmonize. It may be a
            one-shot iterator.

    Returns:
        list[pl.DataFrame]: The harmonized frames, in input order.

    Raises:
        ValueError: If no frame has any column.
    """
    # Materialize once: the frames are traversed several times below.
    frames = list(dfs)
    cols = set().union(*(df.columns for df in frames))
    if not cols:
        raise ValueError("All dataframes are empty.")

    str_list = pl.List(pl.String)
    for col in sorted(cols):
        types = [df.schema[col] for df in frames if col in df.columns]
        if any(t == pl.String for t in types) and any(t == str_list for t in types):
            # Apply cumulatively to the current frames, so earlier harmonized
            # columns are not discarded by later ones.
            frames = [
                (
                    df.with_columns(pl.col(col).cast(str_list, strict=False).alias(col))
                    if col in df.columns
                    else df
                )
                for df in frames
            ]
    return frames
def guess_variable_source(variable: Variable) -> str | None:
    """Guess which source package a variable was defined in.

    Args:
        variable (Variable): The variable to inspect.

    Returns:
        str | None: The first sub-package name below ``corr_vars.sources`` of
        the module that ``inspect.getmodule`` reports for `variable`, or None
        if that module is unknown or lives outside ``corr_vars.sources``.
    """
    module = inspect.getmodule(variable)
    if module is None:
        return None

    package_prefix = sources.__name__ + "."
    module_name = module.__name__
    if not module_name.startswith(package_prefix):
        return None

    remainder = module_name[len(package_prefix):]
    head, _, _ = remainder.partition(".")
    return head
def find_unknown_keys(
    user_cfg: Mapping[str, Any],
    schema: Mapping[str, Any],
    prefix: str = "",
    warnings: list[str] | None = None,
) -> list[str]:
    """Collect a message for every key of `user_cfg` missing from `schema`.

    Nested mappings are checked recursively when both the config value and the
    schema value are mappings. Keys are reported as dotted paths.

    Args:
        user_cfg (Mapping[str, Any]): The user-provided configuration.
        schema (Mapping[str, Any]): The reference configuration.
        prefix (str): Dotted path of `user_cfg` within the root config.
        warnings (list[str] | None): List to append to; a new list if None.

    Returns:
        list[str]: The accumulator list (the same object as `warnings` if given).
    """
    collected = [] if warnings is None else warnings
    for name, sub_cfg in user_cfg.items():
        dotted = f"{prefix}.{name}" if prefix else name
        if name not in schema:
            collected.append(f"Unknown key in config: {dotted}")
        elif isinstance(sub_cfg, Mapping) and isinstance(schema.get(name), Mapping):
            find_unknown_keys(sub_cfg, schema[name], dotted, collected)
    return collected
def deep_merge(base: T, override: Mapping[str, Any]) -> T:
    """
    Recursively merge `override` into `base` in place; `override` wins.

    When both sides hold a mapping for the same key (and the `base` side is
    mutable), the merge recurses. Otherwise the `override` value replaces the
    `base` value. Works on plain dicts, including TypedDict-shaped ones.

    Returns:
        The mutated `base` object.
    """
    for key, incoming in override.items():
        existing = base[key] if key in base else None
        if not (isinstance(existing, MutableMapping) and isinstance(incoming, Mapping)):
            base[key] = incoming
            continue
        deep_merge(existing, incoming)
    return base
def load_source_defaults(
    source_kwargs: Mapping[str, Any],
    defaults: T,
    migrations: Mapping[str, tuple[str, str]] | None = None,
) -> T:
    """Build a complete source config by layering user kwargs over defaults.

    Steps:
        1. Copy `source_kwargs` (the top level only).
        2. Move legacy keys: for each ``old_key -> (new_parent, new_key)`` in
           `migrations`, move ``config[old_key]`` to
           ``config[new_parent][new_key]``.
        3. Log a warning for every key not present in `defaults`.
        4. Deep-merge the config onto a deep copy of `defaults`.

    Args:
        source_kwargs (Mapping[str, Any]): User-provided source configuration.
            It is not modified.
        defaults (T): The default configuration. It is not modified.
        migrations (Mapping[str, tuple[str, str]] | None): Legacy key moves.

    Returns:
        T: The merged configuration.
    """
    # 1. Load source config (shallow copy: top-level pops/sets stay local)
    config: dict[str, Any] = dict(source_kwargs)

    # 2. Apply migrations for old keys
    if migrations:
        for old_key, (new_parent, new_key) in migrations.items():
            if old_key in config:
                # Build a fresh parent dict instead of writing into the existing
                # one, which may be a nested dict owned by the caller.
                parent = dict(config.get(new_parent) or {})
                parent[new_key] = config.pop(old_key)
                config[new_parent] = parent

    # 3. Log warnings
    unknown_keys = find_unknown_keys(config, defaults)
    log_collection(logger=logger, collection=unknown_keys, level=logging.WARNING)

    # 4. Deep merge overrides on defaults
    result = deep_merge(deepcopy(defaults), config)

    # 5. `result` now has the full structure of `defaults`; tell the type checker
    return cast(T, result)