from __future__ import annotations
from datetime import datetime, date, timezone, timedelta
from enum import Flag, auto
import polars as pl
from corr_vars.definitions.typing import is_time_anchor_column, TimeAnchorColumn
from polars._typing import ClosedInterval
from polars.expr.whenthen import Then, ChainedThen
from typing import TypeAlias
# Literal (non-column) values that the time-bound parameters below are annotated to accept.
# NOTE(review): ``int`` is part of the alias, but TimeAnchor.__init__ only has explicit
# handling for datetime/date literals -- confirm whether ints are meant to be supported.
DateLiteral: TypeAlias = datetime | date | int
[docs]
class TimeAnchor:
    """A single point in time, either a literal value or derived from a column.

    Three input forms are handled by the constructor:

    * a ``datetime``/``date`` literal,
    * a column name (``str``),
    * a ``(column, offset)`` tuple, where ``offset`` is a duration string that
      is passed to ``Expr.dt.offset_by`` (a leading ``"+"`` is stripped).

    Instances define ``__eq__`` without ``__hash__`` and are therefore
    unhashable.
    """

    column: str | None  # referenced column name; None for literal anchors
    delta: str | None  # offset string without a leading "+"; None if no offset
    expr: pl.Expr  # expression producing the anchor, always aliased to "time"

    def __init__(self, bound: TimeAnchorColumn | DateLiteral) -> None:
        """Build the anchor expression from ``bound``.

        Args:
            bound: A datetime/date literal, a column name, or a
                ``(column, offset)`` tuple.

        Raises:
            TypeError: If ``bound`` is neither a datetime/date literal nor
                accepted by ``is_time_anchor_column``.
            ValueError: If the offset of a ``(column, offset)`` tuple is empty
                (``""`` or just ``"+"``).
        """
        # datetime is a subclass of date, so one isinstance check covers both.
        if isinstance(bound, date):
            self.column = None
            self.delta = None
            anchor = pl.lit(bound)
        elif not is_time_anchor_column(bound):
            raise TypeError(
                "Expected TimeAnchorColumn to be a string or a tuple of two strings, "
                f"or a datetime/date literal; got {type(bound).__name__}."
            )
        elif isinstance(bound, str):
            self.column = bound
            self.delta = None
            anchor = pl.col(bound)
        else:
            column, delta = bound
            # Store the offset without an explicit "+" sign; "-" is kept.
            offset = delta.removeprefix("+")
            if not offset:
                # Previously surfaced as a bare IndexError (or later in __repr__).
                raise ValueError(
                    f"Offset for column {column!r} must not be empty, got {delta!r}."
                )
            self.column = column
            self.delta = offset
            anchor = pl.col(column).dt.offset_by(offset)
        self.expr = anchor.alias("time")

    def __repr__(self) -> str:
        """Return a short human-readable description of the anchor.

        ``"<column> +/- <offset>"`` for relative anchors, the column name for
        plain column anchors, and ``"Literal: <value>"`` for literal anchors.
        """
        if self.column is not None and self.delta is not None:
            sign = "-" if self.delta.startswith("-") else "+"
            magnitude = self.delta.removeprefix("-")
            return f"{self.column} {sign} {magnitude}"
        if self.column:
            return self.column
        # ``self.expr`` is always aliased, and ``meta.is_literal()`` does not
        # accept aliased literals by default, so decide by root columns instead:
        # an expression without root columns can be evaluated without a frame.
        roots = self.expr.meta.root_names()
        if roots:
            return ", ".join(roots)
        return f"Literal: {pl.select(self.expr).item()}"

    def __eq__(self, other: object) -> bool:
        """Anchors are equal when column, offset and expression all match."""
        if not isinstance(other, TimeAnchor):
            return False
        return (
            self.column == other.column
            and self.delta == other.delta
            and self.expr.meta.eq(other.expr)
        )

    @property
    def is_relative(self) -> bool:
        """True when the anchor is a column shifted by an offset."""
        return self.column is not None and self.delta is not None

    @property
    def is_literal(self) -> bool:
        """True when the anchor references no column and has no offset."""
        return self.column is None and self.delta is None
[docs]
class TimeWindow:
    """A time interval delimited by two ``TimeAnchor`` bounds.

    Besides holding the bounds, the class exposes polars expressions for the
    bounds themselves, for derived quantities (duration, midpoint) and for
    validity checks on the bounds.

    Equality compares only ``tmin`` and ``tmax``; the ``user_specified_*``
    flags are not part of the comparison.
    """

    tmin: TimeAnchor
    tmax: TimeAnchor
    user_specified_tmin: bool
    user_specified_tmax: bool

    def __init__(
        self,
        tmin: TimeAnchor | TimeAnchorColumn | DateLiteral | None = None,
        tmax: TimeAnchor | TimeAnchorColumn | DateLiteral | None = None,
        user_specified: bool | tuple[bool, bool] = False,
    ) -> None:
        """Create a window from two bounds.

        Args:
            tmin: Lower bound. ``TimeAnchor`` instances are used as-is; other
                values are wrapped in a ``TimeAnchor``. A falsy value (e.g.
                ``None``) is replaced by ``datetime.min``.
            tmax: Upper bound, handled like ``tmin``. A falsy value is replaced
                by the current wall-clock time at a fixed UTC+1 offset, made
                naive.
            user_specified: Either one flag applied to both bounds, or a
                ``(tmin_flag, tmax_flag)`` tuple. The flags are only stored.
        """
        if not isinstance(tmin, TimeAnchor):
            # datetime.min is already naive, so no tzinfo handling is needed.
            tmin = TimeAnchor(tmin or datetime.min)
        if not isinstance(tmax, TimeAnchor):
            # NOTE(review): fixed UTC+1 offset (no DST handling) -- presumably
            # meant as local wall-clock time; confirm.
            tmax = TimeAnchor(
                tmax
                or datetime.now(tz=timezone(timedelta(hours=1))).replace(tzinfo=None)
            )
        self.tmin = tmin
        self.tmax = tmax

        if isinstance(user_specified, tuple):
            self.user_specified_tmin = user_specified[0]
            self.user_specified_tmax = user_specified[1]
        else:
            self.user_specified_tmin = user_specified
            self.user_specified_tmax = user_specified

    def __repr__(self) -> str:
        return f"TimeWindow(tmin={self.tmin}, tmax={self.tmax})"

    def __eq__(self, other: object) -> bool:
        """Windows are equal when both bounds are equal."""
        if not isinstance(other, TimeWindow):
            return False
        return self.tmin == other.tmin and self.tmax == other.tmax

    @property
    def user_specified(self) -> bool:
        """True when at least one of the two bounds is flagged as user specified."""
        return self.user_specified_tmin or self.user_specified_tmax

    ####################
    #   Time Anchors   #
    ####################

    # The getters below used to declare an ``alias`` parameter. A property is
    # read via attribute access and can never receive arguments, so the
    # parameter was dead; the output names are the former defaults.

    @property
    def tmin_expr(self) -> pl.Expr:
        """Lower-bound expression, named ``"tmin"``."""
        return self.tmin.expr.alias("tmin")

    @property
    def tmax_expr(self) -> pl.Expr:
        """Upper-bound expression, named ``"tmax"``."""
        return self.tmax.expr.alias("tmax")

    ####################
    #  Time Concepts   #
    ####################

    # TODO: Use or discard
    @property
    def tmid_expr(self) -> pl.Expr:
        """Midpoint expression (``tmin + duration / 2``), named ``"tmid"``."""
        half_duration = self.duration_expr.truediv(2)
        return self.tmin_expr.add(half_duration).alias("tmid")

    # TODO: Use or discard
    @property
    def duration_expr(self) -> pl.Expr:
        """Window length expression (``tmax - tmin``), named ``"duration"``."""
        return (self.tmax_expr - self.tmin_expr).alias("duration")

    @property
    def is_relative(self) -> bool:
        """True when at least one bound is a column shifted by an offset."""
        return self.tmin.is_relative or self.tmax.is_relative

    @property
    def is_literal(self) -> bool:
        """True when both bounds are literals."""
        return self.tmin.is_literal and self.tmax.is_literal

    ####################
    #     Validity     #
    ####################

    @property
    def any_timebound_in_year_9999_expr(self) -> pl.Expr:
        """Expression that is true when either bound falls in the year 9999.

        Null bounds count as "not in year 9999" (``fill_null(False)``).
        NOTE(review): 9999 presumably marks a placeholder date -- confirm.
        """
        return pl.any_horizontal(
            self.tmin_expr.dt.year().eq(pl.lit(9999)).fill_null(False),
            self.tmax_expr.dt.year().eq(pl.lit(9999)).fill_null(False),
        )

    @property
    def any_timebound_null_expr(self) -> pl.Expr:
        """Expression that is true when either bound is null."""
        return pl.any_horizontal(
            self.tmin_expr.is_null(),
            self.tmax_expr.is_null(),
        )

    @property
    def tmin_after_tmax_expr(self) -> pl.Expr:
        """Expression that is true when the lower bound lies after the upper bound."""
        return self.tmin_expr > self.tmax_expr

    @property
    def any_timebound_invalid_expr(self) -> pl.Expr:
        """Expression combining all validity checks: null, year 9999, or inverted bounds."""
        return pl.any_horizontal(
            self.any_timebound_null_expr,
            self.any_timebound_in_year_9999_expr,
            self.tmin_after_tmax_expr,
        )
[docs]
class WindowRelation(Flag):
    """Bit flags describing how a record interval lies relative to a window.

    The four primary members each occupy one bit; the ``ANY_*`` members are
    unions of them, so several relations can be requested at once.
    """

    # Primary relations (one bit each, assigned in declaration order).
    CONTAINED = auto()  # record lies fully inside [tmin, tmax]
    OVERLAPS_LEFT = auto()  # record starts before tmin and ends inside the window
    OVERLAPS_RIGHT = auto()  # record starts inside the window and ends after tmax
    COVERS_WINDOW = auto()  # record extends beyond both window boundaries

    # Frequently used unions of the primary relations.
    ANY_LEFT_OVERLAP = CONTAINED | OVERLAPS_LEFT | COVERS_WINDOW
    ANY_RIGHT_OVERLAP = CONTAINED | OVERLAPS_RIGHT | COVERS_WINDOW
    ANY_OVERLAP = ANY_LEFT_OVERLAP | ANY_RIGHT_OVERLAP
[docs]
class WindowOverlap:
    """Static builders for polars expressions that relate records to a time window."""

    @staticmethod
    def contains_expr(
        record: TimeAnchor, window: TimeWindow, closed: ClosedInterval = "both"
    ) -> pl.Expr:
        """Build an expression testing whether a point-in-time record lies in ``window``.

        Args:
            record: Anchor whose expression is tested.
            window: Window supplying the lower and upper bound expressions.
            closed: Forwarded to ``Expr.is_between`` to choose which bounds
                are inclusive.
        """
        lower = window.tmin_expr
        upper = window.tmax_expr
        return record.expr.is_between(lower, upper, closed=closed)

    @staticmethod
    def relation_expr(
        record: TimeWindow,
        window: TimeWindow,
    ) -> dict[WindowRelation, pl.Expr]:
        """Build one boolean expression per primary ``WindowRelation``.

        Args:
            record: Interval being classified.
            window: Reference window.

        Returns:
            Mapping from each of the four primary relations to the expression
            that is true when ``record`` stands in that relation to ``window``.
        """
        rec_start, rec_end = record.tmin_expr, record.tmax_expr
        win_start, win_end = window.tmin_expr, window.tmax_expr

        contained = (rec_start >= win_start) & (rec_end <= win_end)
        overlaps_left = (
            (rec_start < win_start) & (rec_end >= win_start) & (rec_end <= win_end)
        )
        overlaps_right = (
            (rec_start >= win_start) & (rec_start <= win_end) & (rec_end > win_end)
        )
        covers_window = (rec_start < win_start) & (rec_end > win_end)

        return {
            WindowRelation.CONTAINED: contained,
            WindowRelation.OVERLAPS_LEFT: overlaps_left,
            WindowRelation.OVERLAPS_RIGHT: overlaps_right,
            WindowRelation.COVERS_WINDOW: covers_window,
        }

    @staticmethod
    def overlap_expr(
        record: TimeWindow,
        window: TimeWindow,
        relation: WindowRelation,
    ) -> pl.Expr:
        """Build an expression that is true when any relation set in ``relation`` holds.

        Args:
            record: Interval being tested.
            window: Reference window.
            relation: One or more ``WindowRelation`` flags, OR-ed together.

        Raises:
            ValueError: If ``relation`` selects none of the primary relations.
        """
        selected = [
            condition
            for flag, condition in WindowOverlap.relation_expr(record, window).items()
            if relation & flag
        ]
        if not selected:
            raise ValueError(f"Invalid relation: {relation}")

        # OR the selected conditions together, left to right.
        combined = selected[0]
        for condition in selected[1:]:
            combined = combined | condition
        return combined

    # TODO: Use or discard
    @staticmethod
    def classify_expr(
        record: TimeWindow,
        window: TimeWindow,
        *,
        default: WindowRelation | None = None,
    ) -> pl.Expr:
        """Build an expression yielding the integer value of the matching relation.

        Args:
            record: Interval being classified.
            window: Reference window.
            default: Relation whose value is used when no primary relation
                matches; with ``None`` the result is a null literal instead.
        """
        conditions = WindowOverlap.relation_expr(record, window)
        first, *remaining = (
            WindowRelation.CONTAINED,
            WindowRelation.OVERLAPS_LEFT,
            WindowRelation.OVERLAPS_RIGHT,
            WindowRelation.COVERS_WINDOW,
        )

        chain: Then | ChainedThen = pl.when(conditions[first]).then(
            pl.lit(first.value)
        )
        for relation in remaining:
            chain = chain.when(conditions[relation]).then(pl.lit(relation.value))

        fallback = None if default is None else default.value
        return chain.otherwise(pl.lit(fallback))
[docs]
def compatible_with_time_anchor(
    df: pl.DataFrame | pl.LazyFrame, time_anchor: TimeAnchor
) -> bool:
    """Report whether ``df`` provides the column that ``time_anchor`` refers to.

    Literal anchors need no column and are always compatible. For lazy frames
    the column names are read from the collected schema.
    """
    if time_anchor.is_literal:
        return True

    wanted = time_anchor.column or ""
    available = (
        df.columns if isinstance(df, pl.DataFrame) else df.collect_schema().names()
    )
    return wanted in available
[docs]
def compatible_with_time_window(
    df: pl.DataFrame | pl.LazyFrame, time_window: TimeWindow
) -> bool:
    """Report whether ``df`` is compatible with both bounds of ``time_window``.

    The lower bound is checked first; the upper bound is only checked when the
    lower bound is compatible.
    """
    return all(
        compatible_with_time_anchor(df, anchor)
        for anchor in (time_window.tmin, time_window.tmax)
    )