from io import BytesIO
from os import PathLike
import pandas as pd
import polars as pl
import polars.selectors as cs
from tableone import TableOne
from corr_vars import logger, __
from collections.abc import Callable, Hashable, Iterable
from corr_vars.definitions import (
PolarsFrame,
TimeAnchorColumn,
StataDateFormat,
ObsLevel,
)
from polars._typing import AsofJoinStrategy
from typing import Literal, overload, TYPE_CHECKING
from corr_vars.definitions.typing import CleaningDict
if TYPE_CHECKING:
from corr_vars.utils.time import TimeAnchor
# ------------------------
# Converter
# ------------------------
[docs]
def convert_to_polars_df(
value: pl.DataFrame | pl.LazyFrame | pd.DataFrame,
) -> pl.DataFrame:
if isinstance(value, pl.DataFrame):
return value
elif isinstance(value, pl.LazyFrame):
return value.collect()
elif isinstance(value, pd.DataFrame):
return pl.from_pandas(value)
else:
raise TypeError(f"Conversion failed. Unsupported data type: {type(value)}")
[docs]
def convert_to_polars_lf(
value: pl.DataFrame | pl.LazyFrame | pd.DataFrame,
) -> pl.LazyFrame:
if isinstance(value, pl.DataFrame):
return value.lazy()
elif isinstance(value, pl.LazyFrame):
return value
elif isinstance(value, pd.DataFrame):
return pl.from_pandas(value).lazy()
else:
raise TypeError(f"Conversion failed. Unsupported data type: {type(value)}")
[docs]
def convert_to_pandas_df(
    value: pl.DataFrame | pl.LazyFrame | pd.DataFrame,
) -> pd.DataFrame:
    """Return ``value`` as a ``pd.DataFrame``.

    A ``pd.DataFrame`` is passed through unchanged, a ``pl.DataFrame`` is
    converted with ``to_pandas`` and a ``pl.LazyFrame`` is collected first.

    Raises:
        TypeError: If ``value`` is none of the three supported frame types.
    """
    if isinstance(value, pl.DataFrame):
        return value.to_pandas()
    if isinstance(value, pl.LazyFrame):
        eager = value.collect()
        return eager.to_pandas()
    if isinstance(value, pd.DataFrame):
        return value
    raise TypeError(f"Conversion failed. Unsupported data type: {type(value)}")
[docs]
def convert_to_stata(
    df: pd.DataFrame,
    convert_dates: (
        dict[
            Hashable,
            StataDateFormat,
        ]
        | None
    ) = None,
    write_index: bool = True,
    to_file: str | PathLike[str] | None = None,
) -> pd.DataFrame | None:
    """
    Convert a pandas DataFrame to a Stata compatible DataFrame or write it to a Stata file.

    Object columns are cast to ``str`` and datetime values outside
    ``[pd.Timestamp.min, pd.Timestamp.max]`` are replaced by ``pd.NA``.
    The dta ``version`` passed to ``DataFrame.to_stata`` is 114 when the longest
    string is shorter than 245 characters and ``None`` otherwise.
    The input DataFrame is copied first and is not modified.

    Args:
        df (pd.DataFrame): The DataFrame to be converted to Stata format.
        convert_dates (dict[Hashable, str]): Dictionary of columns to convert to Stata date format.
        write_index (bool): Whether to write the index as a column.
        to_file (str | None): Path to save as .dta file. If left unspecified, the DataFrame will not be saved.

    Returns:
        pd.DataFrame: A Pandas Dataframe compatible with Stata if `to_file` is None,
        otherwise `None` (the data is written to `to_file`).
    """
    # Work on a copy so the caller's DataFrame is not modified in place.
    df = df.copy()

    object_cols = df.select_dtypes(include=["object"]).columns
    if len(object_cols) > 0:
        df[object_cols] = df[object_cols].astype(str)

    format_guidance = []
    for col in df.select_dtypes(include=["datetime64"]).columns:
        format_guidance.append(f"format {col} %tc")
        df.loc[:, col] = df[col].mask(
            (df[col] < pd.Timestamp.min) | (df[col] > pd.Timestamp.max),
            pd.NA,
        )
    format_guidance_string = "\n".join(format_guidance)

    # The longest string decides the dta version. The length has to stay an
    # integer: wrapping it in bool() would make the `< 245` check always true.
    string_data = df.select_dtypes(include=["object", "string"])
    if string_data.empty:
        # No string columns (or no rows): nothing restricts the version.
        max_string_length = 0
    else:
        max_string_length = int(string_data.map(len).max(axis=None))

    version = 114 if max_string_length < 245 else None
    logger.info(
        __(
            "Creating a DataFrame compatible with {version} based on max string length {max_string_length}",
            version=version if version else "118 or 119",
            max_string_length=max_string_length,
        )
    )
    logger.info(
        __(
            "You can load the data with a command like this inside a Notebook:\n"
            "%%stata -d stata_df -force\n{format_guidance_string}",
            format_guidance_string=format_guidance_string,
        )
    )

    if to_file:
        df.to_stata(
            to_file,
            convert_dates=convert_dates,
            write_index=write_index,
            version=version,
        )
        return None

    # Round-trip through an in-memory .dta file to obtain the Stata representation.
    output = BytesIO()
    df.to_stata(
        output,
        convert_dates=convert_dates,
        write_index=write_index,
        version=version,
    )
    output.seek(0)
    return pd.read_stata(
        output,
        convert_dates=False,
        convert_missing=False,
        convert_categoricals=False,
    )
[docs]
def convert_to_tableone(
    df: pl.DataFrame,
    ignore_cols: list[str] | str | None = None,
    filter: str | None = None,
    replace_booleans: tuple[str, str] | None = ("Yes", "No"),
    # TableOne kwargs
    display_all: bool = True,
    groupby: str | None = None,
    normal_cols: list[str] | None = None,
    overall: bool | None = None,
    order: dict[str, list[str]] | None = None,
    pval: bool = False,
    **kwargs,
) -> TableOne:
    """
    Create a `TableOne <https://tableone.readthedocs.io/en/latest/index.html>`_ object for a polars DataFrame.

    Args:
        df (pl.DataFrame): The input.
        ignore_cols (list | str | None): Column(s) to ignore.
        filter (str | None): Filter to apply to the data. It is passed to ``pandas.DataFrame.query``
            after booleans have been replaced.
        replace_booleans (tuple[str, str] | None): Replace booleans with the given strings (true value, false value).
        display_all (bool): Whether to display all columns.
        groupby (str | None): Column to group by.
        normal_cols (list[str] | None): Columns to treat as normally distributed. All other
            numeric columns are passed to TableOne as non-normal.
        overall (bool): Whether to add an “overall” column to the table. If left unspecified the overall column will be dropped if groupby is specified.
        order (dict[str, list[str]] | None): Order of categorical columns. The dict is not modified.
        pval (bool): Whether to calculate p-values.
        **kwargs: Additional arguments to pass to TableOne.

    Returns:
        TableOne: A TableOne object.
    """
    # Ignore userspecified columns
    exclude_cols: list[str] = []
    if ignore_cols is not None:
        if isinstance(ignore_cols, str):
            exclude_cols.append(ignore_cols)
        else:
            exclude_cols.extend(ignore_cols)

    # Also ignore datetime, list/array, dict columns, etc.
    unhashable_cols: list[str] = df.select(
        ~(
            cs.boolean()
            | cs.categorical()
            | cs.numeric()
            | cs.string()
            | cs.by_dtype(pl.Enum)
        )
    ).columns
    logger.info(
        __(
            "Skipping unhashable columns: {unhashable_cols}",
            unhashable_cols=lambda: ",".join(unhashable_cols),
        )
    )
    exclude_cols.extend(unhashable_cols)

    # Ignore ID columns
    id_columns: list[str] = [obs_level.primary_key for obs_level in ObsLevel]
    exclude_cols.extend(id_columns)

    # Ignore groupby column (added as special argument)
    if groupby is not None:
        exclude_cols.append(groupby)

    # Collect included columns
    include_cols: list[str] = df.select(pl.all().exclude(exclude_cols)).columns

    # Reorder columns to put icu_id last (because it is so long...)
    if "icu_id" in include_cols:
        include_cols.append(include_cols.pop(include_cols.index("icu_id")))

    # Set non-normal distribution for all numeric columns, except for user specified normal columns
    normal: list[str] = list(normal_cols) if normal_cols else []
    numeric_cols: list[str] = (
        df.select(pl.col(include_cols))
        .select(cs.numeric().exclude("inhospital_death"))
        .columns
    )
    nonnormal: list[str] = [col for col in numeric_cols if col not in normal]

    # Explicitly set booleans and strings to categorical
    # Copy the user supplied order so that the defaults added below do not leak into the caller's dict
    categorical_order: dict[str, list[str]] = dict(order) if order else {}
    categorical: list[str] = (
        df.select(pl.col(include_cols))
        .select(cs.boolean() | cs.categorical() | cs.string() | cs.by_dtype(pl.Enum))
        .columns
    )

    # Transform / filter data
    boolean: list[str] = df.select(pl.col(include_cols)).select(cs.boolean()).columns
    if replace_booleans is not None:
        df = df.with_columns(
            pl.col(boolean).replace_strict(
                {
                    True: replace_booleans[0],
                    False: replace_booleans[1],
                },
                return_dtype=pl.Utf8,
            )
        )
        for b in boolean:
            # Only override order if unspecified
            categorical_order.setdefault(b, [replace_booleans[0], replace_booleans[1]])

    pd_df = convert_to_pandas_df(df)
    if filter:
        pd_df = pd_df.query(filter)

    logger.info(
        __(
            "Columns: {include_cols}",
            include_cols=lambda: ",".join(include_cols),
        )
    )
    logger.debug(
        __(
            "Unhashable Columns: {unhashable_cols}",
            unhashable_cols=lambda: ",".join(unhashable_cols),
        )
    )
    logger.debug(
        __(
            "Excluded Columns: {exclude_cols}",
            exclude_cols=lambda: ",".join(exclude_cols),
        )
    )
    logger.debug(
        __(
            "Included Columns: {include_cols}",
            include_cols=lambda: ",".join(include_cols),
        )
    )
    logger.info(
        __(
            "Normal: {normal}",
            normal=lambda: ",".join(normal),
        )
    )
    logger.info(
        __(
            "Nonnormal: {nonnormal}",
            nonnormal=lambda: ",".join(nonnormal),
        )
    )
    logger.info(
        __(
            "Categorical: {categorical}",
            categorical=lambda: ",".join(categorical),
        )
    )
    logger.info(
        __(
            "Groupby: {groupby}",
            groupby=groupby,
        )
    )

    # Overall: always shown without groupby, otherwise only on request
    add_overall = overall or (groupby is None)

    return TableOne(
        pd_df,
        columns=include_cols,
        nonnormal=nonnormal,
        categorical=categorical,
        # Pass through kwargs
        display_all=display_all,
        groupby=groupby,
        order=categorical_order,
        overall=add_overall,
        pval=pval,
        **kwargs,
    )
# ------------------------
# MultiSourceVariable Pipes
# ------------------------
# ------------------------
# Quality of life improvements
# ------------------------
[docs]
def absolute_and_relative_value_counts(
    input: pl.LazyFrame | pl.DataFrame,
    cols: list[str] | str,
    sort: Literal["cols", "counts"] | None = "counts",
    decimals: int = 2,
) -> pl.DataFrame:
    """
    Count the occurrences of each value (combination) of `cols` in a `pl.LazyFrame` or `pl.DataFrame`.

    The result has one row per group with the columns `cols`, "absolute (int)"
    (number of rows) and "relative (%)" (share of all rows in percent, rounded
    to `decimals`). With `sort="counts"` the rows are ordered by descending
    count, with `sort="cols"` by `cols`; any other value leaves them unsorted.
    """
    lf = input if not isinstance(input, pl.DataFrame) else input.lazy()

    # Total number of rows, used as denominator of the relative share
    n_rows = lf.select(pl.len()).collect().item()

    absolute = pl.len().alias("absolute (int)")
    relative = (100 * pl.len() / n_rows).round(decimals).alias("relative (%)")
    counts = lf.group_by(cols).agg(absolute, relative).collect()

    if sort == "counts":
        return counts.sort("absolute (int)", descending=True)
    if sort == "cols":
        return counts.sort(cols)
    return counts
# ------------------------
# Expression pipes
# ------------------------
[docs]
def as_expr(value: pl.Expr | str) -> pl.Expr:
    """Return `value` unchanged if it is an expression, otherwise wrap it in `pl.col`."""
    return value if isinstance(value, pl.Expr) else pl.col(value)
[docs]
def unique_sucessive(value: pl.Expr | str) -> pl.Expr:
    """Drop successive duplicate values of a list column.

    An element is kept when it is not equal to the element before it or when
    the shifted predecessor is null (as for the first element of a list).
    """
    predecessor = pl.element().shift(1)
    keep = pl.element().ne(predecessor) | predecessor.is_null()
    return as_expr(value).list.filter(keep)
[docs]
def time_difference(
    time_col: pl.Expr | str,
    reference_col: pl.Expr | str,
    *,
    unit: Literal["s", "m", "h", "d", "w"] = "s",
    total: bool = False,
) -> pl.Expr:
    """Build an expression for `time_col - reference_col` expressed in `unit`.

    The difference is taken in total seconds. For units other than "s" it is
    divided by the number of seconds per unit: with true division if `total`
    is set, with floor division otherwise.

    Raises:
        ValueError: If `unit` is not one of "s", "m", "h", "d" or "w".
    """
    seconds = as_expr(time_col).sub(as_expr(reference_col)).dt.total_seconds()
    if unit == "s":
        return seconds

    minute = 60
    hour = 60 * minute
    day = 24 * hour
    seconds_per_unit = {"m": minute, "h": hour, "d": day, "w": 7 * day}
    if unit not in seconds_per_unit:
        raise ValueError(f"Invalid unit: {unit}")

    divisor = seconds_per_unit[unit]
    if total:
        return seconds.truediv(divisor)
    return seconds.floordiv(divisor)
# ------------------------
# DataFrame / LazyFrame pipes
# ------------------------
# Typing overloads: the returned frame has the same type as both inputs.
@overload
def remove_asof(
    main: pl.DataFrame,
    ref: pl.DataFrame,
    strategy: AsofJoinStrategy = ...,
    by: str = ...,
    on: str = ...,
    tolerance: str = ...,
) -> pl.DataFrame: ...


@overload
def remove_asof(
    main: pl.LazyFrame,
    ref: pl.LazyFrame,
    strategy: AsofJoinStrategy = ...,
    by: str = ...,
    on: str = ...,
    tolerance: str = ...,
) -> pl.LazyFrame: ...
[docs]
def remove_asof(
    main: PolarsFrame,
    ref: PolarsFrame,
    strategy: AsofJoinStrategy = "nearest",
    by: str = "case_id",
    on: str = "recordtime",
    tolerance: str = "1d",
) -> PolarsFrame:
    """Remove the rows of `main` that are matched by an asof join from `ref`.

    Each row of `ref` is asof-joined to `main` (grouped by `by`, on `on`,
    within `tolerance`). The matched `main` rows are then removed from `main`
    with an anti join on `[by, on]`.

    Raises:
        TypeError: If `main` and `ref` are not of the same type.
    """
    if type(main) is not type(ref):
        raise TypeError("Types of main and ref are not the same.")

    # The asof join runs from ref towards main, so the search direction of
    # the requested strategy has to be mirrored.
    mirrored: dict[AsofJoinStrategy, AsofJoinStrategy] = {
        "nearest": "nearest",
        "forward": "backward",
        "backward": "forward",
    }
    ref_strategy = mirrored[strategy]

    ref_keys = ref.sort(by, on).select(by, on)
    main_keys = main.sort(by, on).select(by, on)

    # coalesce=False keeps the join column of main as f"{on}_main"
    matched = ref_keys.join_asof(
        main_keys,  # type: ignore
        on=on,
        suffix="_main",
        by=by,
        strategy=ref_strategy,
        tolerance=tolerance,
        coalesce=False,
    )
    matched_main_keys = matched.drop(on).rename({f"{on}_main": on})

    return main.join(matched_main_keys, on=[by, on], how="anti")
[docs]
def find_nearest(
    df: pl.DataFrame,
    time: TimeAnchorColumn,
    pm_before: str = "52w",
    pm_after: str = "52w",
    recordtime: str = "recordtime",
) -> pl.DataFrame:
    """Find the closest record to the target time.

    Args:
        df (pl.DataFrame): The input dataframe.
        time (TimeAnchorColumn): The target time. It is wrapped in a `TimeAnchor`
            whose `expr` is used as the target time expression.
        pm_before (str): The time range before the target time. Defaults to "52w".
        pm_after (str): The time range after the target time. Defaults to "52w".
        recordtime (str): The column holding the record time. Defaults to "recordtime".

    Returns:
        pl.DataFrame: The closest record (a single row), or an empty DataFrame
        if no record lies within the time range.
    """
    # At module level TimeAnchor is only imported under TYPE_CHECKING,
    # so it has to be imported here to be available at runtime.
    from corr_vars.utils.time import TimeAnchor

    target_time_anchor = TimeAnchor(time)

    # Absolute distance between the record time and the target time
    time_range = pl.col(recordtime).sub(target_time_anchor.expr).abs()
    within_range = pl.col(recordtime).is_between(
        target_time_anchor.expr.dt.offset_by(f"-{pm_before}"),
        target_time_anchor.expr.dt.offset_by(pm_after),
    )

    filtered_df = df.filter(within_range)
    if filtered_df.is_empty():
        return filtered_df

    closest = filtered_df.sort(time_range).head(1)
    return closest
[docs]
def apply_cleaning(df: pl.DataFrame, cleaning: CleaningDict) -> pl.DataFrame:
    """Filter out rows whose values violate the bounds given in `cleaning`.

    For every column in `cleaning`, only rows with a value greater than or
    equal to the "low" bound and less than or equal to the "high" bound are
    kept; either bound may be absent. The input is returned as is when
    `cleaning` or `df` is empty.

    Returns:
        cleaned (pl.DataFrame): Cleaned DataFrame
    """
    if not cleaning or df.is_empty():
        return df

    cleaned = df
    for column, limits in cleaning.items():
        predicates: list[pl.Expr] = []
        if "low" in limits:
            predicates.append(pl.col(column).ge(limits["low"]))
        if "high" in limits:
            predicates.append(pl.col(column).le(limits["high"]))
        for predicate in predicates:
            cleaned = cleaned.filter(predicate)
    return cleaned
[docs]
def interval_bucket_agg(
    data: PolarsFrame,
    primary_key: str = "case_id",
    recordtime: str = "recordtime",
    recordtime_end: str | None = None,
    recordtime_end_gap_treshold: str | None = None,
    recordtime_end_gap_behavior: Literal["impute", "drop"] = "drop",
    duration_empty_impute_duration: str | None = None,
    value: str = "value",
    tmin: str | None = None,
    tmax: str | None = None,
    t0: str | None = None,
    t_missing_behavior: Literal["keep", "drop"] = "drop",
    granularity: str | None = "1h",
    same_bin_aggregation: Callable[[pl.Expr], pl.Expr] | None = None,
) -> PolarsFrame:
    """
    Bins intervals clipped between tmin and tmax into bins aligned to t0 and aggregates on these bins.

    Values are distributed over the bins proportionally to the overlap duration
    between the (clipped) interval and the bin.

    Args:
        data (pl.DataFrame | pl.LazyFrame): The data to transform.
        primary_key (str): The primary key column. Defaults to "case_id".
        recordtime (str): The recordtime column. Defaults to "recordtime".
        recordtime_end (str | None): The recordtime_end column. Defaults to `None`. Will impute recordtime_end from the next recordtime if set to `None`.
        recordtime_end_gap_treshold (str | None): Maximum allowed gap for recordtime_end. Only applies if recordtime_end is set to `None`.
        recordtime_end_gap_behavior (str | None): Determines whether to impute or drop rows when the `recordtime_end_gap_treshold` is exceeded.
        duration_empty_impute_duration (str | None): The duration to be imputed by offsetting recordtime_end of intervals with a duration of zero. Drops rows if set to default of `None` or ≤0s.
        value (str): The value column. Defaults to "value".
        tmin (str | None): The tmin column. Defaults to `None`.
        tmax (str | None): The tmax column. Defaults to `None`.
        t0 (str | None): The t0 column. Bins will align to this datetime col if specified. Defaults to `None`.
        t_missing_behavior (str | None): Determines whether to drop or keep rows when `t0`, `tmin` or `tmax` columns are specified and values are missing in `data`. Defaults to "drop".
        granularity (str | None): The size of the bucket intervals. Defaults to 1 hour.
        same_bin_aggregation: (Callable[[pl.Expr], pl.Expr] | None): The aggregation to perform on `value` for each bin. Uses `pl.sum()` if set to `None`.

    .. Note::
        The `granularity`, `recordtime_end_gap_treshold` and `duration_empty_impute_duration` arguments are created with
        the following string language:

        - 1ns (1 nanosecond) # nanoseconds are not supported by our DataFrames.
        - 1us (1 microsecond)
        - 1ms (1 millisecond)
        - 1s (1 second)
        - 1m (1 minute)
        - 1h (1 hour)
        - 1d (1 calendar day)
        - 1w (1 calendar week)
        - 1mo (1 calendar month)
        - 1q (1 calendar quarter)
        - 1y (1 calendar year)

        Or combine them: "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds

        By "calendar day", we mean the corresponding time on the next day (which may
        not be 24 hours, due to daylight savings). Similarly for "calendar week",
        "calendar month", "calendar quarter", and "calendar year".

        If you do not wish to use "calendar day" or any other "calendar" durations, specify the duration in hours or smaller units instead.

    Returns:
        clip_bin_data (pl.DataFrame | pl.LazyFrame): The clipped and binned data. Same type as `data`.
        Contains `primary_key`, the grouping columns of the bins (if any), the aggregated
        value ("agg_value") and a struct column "attributes" with the aggregated overlap
        start, end and duration in seconds (plus t0 related attributes if `t0` is given).
    """
    lazy_input = data if isinstance(data, pl.LazyFrame) else data.lazy()

    # Internal column names
    START = "start_time"
    END = "end_time"
    DURATION = "duration_time"
    VALUE = "value"
    TMIN = "tmin"
    TMAX = "tmax"
    T0 = "t0"
    START_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{START}"
    END_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{END}"
    DURATION_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{DURATION}"
    VALUE_CLIPPED_TMIN_TMAX = f"clipped_tmin_tmax_{VALUE}"
    START_SPLIT = f"split_{START}"
    END_SPLIT = f"split_{END}"
    START_OVERLAP = f"overlap_{START}"
    END_OVERLAP = f"overlap_{END}"
    DURATION_OVERLAP = f"overlap_{DURATION}"
    VALUE_OVERLAP = f"overlap_{VALUE}"
    BEFORE_OR_AFTER_T0_SIGNED = "t0_side_signed"
    T0_GRANULARITY_DISTANCE = "t0_granularity_distance"
    AGG_START_OVERLAP = f"first_{START_OVERLAP}"
    AGG_END_OVERLAP = f"last_{END_OVERLAP}"
    AGG_DURATION_OVERLAP_S = f"max_{DURATION_OVERLAP}_s"
    AGG_VALUE = f"agg_{VALUE}"
    ATTRIBUTES = "attributes"

    # Select required columns from input and rename them
    # Fill recordtime_end, tmin, tmax and t0 with None if unspecified
    lf = lazy_input.select(
        primary_key,
        pl.col(recordtime).alias(START),
        (
            pl.lit(None).alias(END)
            if recordtime_end is None
            else pl.col(recordtime_end).alias(END)
        ),
        pl.col(value).alias(VALUE),
        pl.lit(None).alias(TMIN) if tmin is None else pl.col(tmin).alias(TMIN),
        pl.lit(None).alias(TMAX) if tmax is None else pl.col(tmax).alias(TMAX),
        pl.lit(None).alias(T0) if t0 is None else pl.col(t0).alias(T0),
    )

    # impute recordtime_end if necessary
    if recordtime_end is None:
        # Sort and calculate end_time by shifting over partition
        lf = lf.sort(primary_key, START).with_columns(
            # NOTE: Create nulls for the last entry of each partition
            # Will be dropped in the next step
            pl.col(START)
            .shift(-1)
            .over(primary_key)
            .alias(END)
        )

        # Enforce max gap treshold
        if recordtime_end_gap_treshold:
            if recordtime_end_gap_behavior == "drop":
                lf = lf.remove(
                    pl.col(START)
                    .dt.offset_by(recordtime_end_gap_treshold)
                    .lt(pl.col(END))
                )
            else:
                lf = lf.with_columns(
                    pl.min_horizontal(
                        pl.col(START).dt.offset_by(recordtime_end_gap_treshold),
                        pl.col(END),
                    ).alias(END)
                )

    # Drop nulls and calculate duration
    lf = lf.drop_nulls(
        [primary_key, START, END, VALUE]
        + ([TMIN] if tmin is not None and t_missing_behavior == "drop" else [])
        + ([TMAX] if tmax is not None and t_missing_behavior == "drop" else [])
        + ([T0] if t0 is not None and t_missing_behavior == "drop" else [])
    ).with_columns(pl.col(END).sub(pl.col(START)).alias(DURATION))

    if duration_empty_impute_duration:
        lf = lf.with_columns(
            pl.when(pl.col(DURATION).eq(0))
            .then(pl.col(END).dt.offset_by(duration_empty_impute_duration))
            .otherwise(pl.col(END))
            .alias(END)
        ).with_columns(
            # The duration has to be recalculated from the imputed end time,
            # otherwise the imputed rows would still be dropped as empty below
            pl.col(END).sub(pl.col(START)).alias(DURATION)
        )

    # Filter out empty intervals and intervals not intersecting [tmin, tmax]
    lf_clipped = (
        lf.filter(
            pl.col(DURATION).gt(0)
            & pl.col(START).lt(pl.col(TMAX)).or_(pl.col(TMAX).is_null())
            & pl.col(END).gt(pl.col(TMIN)).or_(pl.col(TMIN).is_null())
        )
        .with_columns(
            pl.col(START).clip(TMIN, TMAX).alias(START_CLIPPED_TMIN_TMAX),
            pl.col(END).clip(TMIN, TMAX).alias(END_CLIPPED_TMIN_TMAX),
        )
        .drop(TMIN, TMAX)
        .with_columns(
            pl.col(END_CLIPPED_TMIN_TMAX)
            .sub(pl.col(START_CLIPPED_TMIN_TMAX))
            .alias(DURATION_CLIPPED_TMIN_TMAX)
        )
    )

    # Calculate corrected values for intervals clipped to [tmin, tmax]
    # Skip this step when tmin or tmax unspecified for performance
    # It is practically a no-op in that case
    corrected_value_expr: pl.Expr
    if tmin is not None or tmax is not None:
        corrected_value_expr = (
            pl.col(VALUE)
            .cast(pl.Float64)
            .mul(pl.col(DURATION_CLIPPED_TMIN_TMAX))
            .truediv(pl.col(DURATION))
        )
    else:
        corrected_value_expr = pl.col(VALUE).cast(pl.Float64)

    lf_clipped = (
        lf_clipped
        # Shouldn't be possible after the previous filter step
        .filter(pl.col(DURATION_CLIPPED_TMIN_TMAX).gt(0)).with_columns(
            corrected_value_expr.alias(VALUE_CLIPPED_TMIN_TMAX)
        )
    )

    # NOTE: Empty string is also not valid
    correction_expr: pl.Expr
    datetime_ranges_expr: pl.Expr
    intersection_condition_expr: pl.Expr
    lf_distributed: pl.LazyFrame
    if granularity:
        # NOTE(review): `pl.datetime_ranges` is called with its default `closed`
        # argument; an interval ending exactly on a bin boundary may therefore
        # yield an additional bin with an overlap duration of zero — confirm
        # whether this is intended.
        # Use t0 as relative anchor for the buckets
        if t0 is not None:
            # Apply correction offset to align bins to a grid starting from t0
            correction_expr = pl.col(T0).sub(pl.col(T0).dt.truncate(granularity))
            datetime_ranges_expr = pl.datetime_ranges(
                pl.col(START_CLIPPED_TMIN_TMAX)
                .sub(correction_expr)
                .dt.truncate(granularity),
                pl.col(END_CLIPPED_TMIN_TMAX).sub(correction_expr),
                granularity,
            )
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    datetime_ranges_expr.alias(START_SPLIT)
                ).explode(START_SPLIT)
                # Revert correction offset
                .with_columns(
                    pl.col(START_SPLIT).add(correction_expr),
                )
            )
        # Snap buckets to previous grid point
        else:
            datetime_ranges_expr = pl.datetime_ranges(
                pl.col(START_CLIPPED_TMIN_TMAX).dt.truncate(granularity),
                END_CLIPPED_TMIN_TMAX,
                granularity,
            )
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    datetime_ranges_expr.alias(START_SPLIT)
                ).explode(START_SPLIT)
            )

        lf_distributed = (
            lf_distributed
            # Calculate end_granularity
            .with_columns(
                pl.col(START_SPLIT).dt.offset_by(granularity).alias(END_SPLIT),
            )
        )
    else:
        # Split intervals containing t0 into two intervals
        if t0 is not None:
            # Find intervals containing t0
            intersection_condition_expr = (
                pl.col(START_CLIPPED_TMIN_TMAX)
                .lt(pl.col(T0))
                .and_(pl.col(END_CLIPPED_TMIN_TMAX).gt(pl.col(T0)))
            )
            # Split intervals containing t0
            lf_distributed = lf_clipped.with_columns(
                pl.when(intersection_condition_expr)
                .then(
                    pl.concat_list(
                        pl.col(START_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)),
                        pl.col(T0).cast(pl.List(pl.Datetime)),
                    )
                )
                .otherwise(pl.col(START_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)))
                .alias(START_SPLIT),
                pl.when(intersection_condition_expr)
                .then(
                    pl.concat_list(
                        pl.col(T0).cast(pl.List(pl.Datetime)),
                        pl.col(END_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)),
                    )
                )
                .otherwise(pl.col(END_CLIPPED_TMIN_TMAX).cast(pl.List(pl.Datetime)))
                .alias(END_SPLIT),
            ).explode(START_SPLIT, END_SPLIT)
        # Do nothing: Split times are equal to previous times and not exploded
        else:
            # Use clipped times as bucket times
            lf_distributed = (
                # Generate buckets
                lf_clipped.with_columns(
                    pl.col(START_CLIPPED_TMIN_TMAX).alias(START_SPLIT),
                    pl.col(END_CLIPPED_TMIN_TMAX).alias(END_SPLIT),
                )
            )

    # Calculate corrected values for the part of each interval overlapping its bucket
    # Skip this step when the intervals were not split for performance
    # It is practically a no-op in that case
    if granularity or t0 is not None:
        corrected_value_expr = (
            pl.col(VALUE_CLIPPED_TMIN_TMAX)  # Already cast to Float64
            .mul(pl.col(DURATION_OVERLAP))
            .truediv(pl.col(DURATION_CLIPPED_TMIN_TMAX))
        )
    else:
        corrected_value_expr = pl.col(VALUE_CLIPPED_TMIN_TMAX)

    lf_distributed = (
        lf_distributed.with_columns(
            pl.max_horizontal(START_CLIPPED_TMIN_TMAX, START_SPLIT).alias(
                START_OVERLAP
            ),
            pl.min_horizontal(END_CLIPPED_TMIN_TMAX, END_SPLIT).alias(END_OVERLAP),
        )
        .with_columns(
            pl.col(END_OVERLAP).sub(pl.col(START_OVERLAP)).alias(DURATION_OVERLAP)
        )
        .with_columns(corrected_value_expr.alias(VALUE_OVERLAP))
        .select(
            primary_key,
            START_SPLIT,
            END_SPLIT,
            T0,
            START_OVERLAP,
            END_OVERLAP,
            VALUE_OVERLAP,
        )
    )

    # Importantly this data is not yet aggregated on the intervals...
    # There can be multiple entries per interval

    # Aggregate interval
    def agg_interval(
        df: pl.LazyFrame, more_group_by_cols: Iterable[pl.Expr | str] = ()
    ) -> pl.LazyFrame:
        return (
            df.group_by(primary_key, *more_group_by_cols)
            .agg(
                pl.min(START_OVERLAP).alias(AGG_START_OVERLAP),
                pl.max(END_OVERLAP).alias(AGG_END_OVERLAP),
                (
                    pl.col(VALUE_OVERLAP).pipe(same_bin_aggregation).alias(AGG_VALUE)
                    if same_bin_aggregation
                    else pl.col(VALUE_OVERLAP).sum().alias(AGG_VALUE)
                ),
            )
            # Recalculate overlap_duration_s with new aggregated timestamps
            .with_columns(
                pl.col(AGG_END_OVERLAP)
                .sub(pl.col(AGG_START_OVERLAP))
                .dt.total_seconds()
                .alias(AGG_DURATION_OVERLAP_S)
            )
            .sort(primary_key, AGG_START_OVERLAP)
        )

    # group_cols are the grouping expressions, group_col_names the names of the
    # columns they produce. After the aggregation only the names can be selected,
    # because the inputs of the expressions no longer exist in the aggregated frame.
    group_cols: list[str | pl.Expr] = []
    group_col_names: list[str] = []
    if granularity:
        group_col_names = [START_SPLIT, END_SPLIT] + ([T0] if t0 is not None else [])
        group_cols += group_col_names
    elif t0 is not None:
        group_cols += [
            pl.when(pl.col(START_SPLIT).lt(pl.col(T0)))
            .then(pl.lit(-1))
            .otherwise(pl.lit(1))
            .alias(BEFORE_OR_AFTER_T0_SIGNED),
        ]
        group_col_names = [BEFORE_OR_AFTER_T0_SIGNED]

    agg_df = lf_distributed.pipe(agg_interval, group_cols)

    extra_attributes: list[pl.Expr | str] = [
        AGG_START_OVERLAP,
        AGG_END_OVERLAP,
        AGG_DURATION_OVERLAP_S,
    ]
    if granularity and t0 is not None:
        # Distance of the bin start to t0 in multiples of the bin size
        extra_attributes += [
            pl.col(START_SPLIT)
            .sub(pl.col(T0))
            .truediv(pl.col(END_SPLIT).sub(pl.col(START_SPLIT)))
            .alias(T0_GRANULARITY_DISTANCE)
        ]
    elif t0 is not None:
        extra_attributes += [BEFORE_OR_AFTER_T0_SIGNED]

    agg_df_packed = agg_df.with_columns(
        pl.struct(extra_attributes).alias(ATTRIBUTES)
    ).select(primary_key, *group_col_names, AGG_VALUE, ATTRIBUTES)

    return agg_df_packed if isinstance(data, pl.LazyFrame) else agg_df_packed.collect()