# Source code for corr_vars.utils.time

from __future__ import annotations

from datetime import datetime, date, timezone, timedelta
from enum import Flag, auto

import polars as pl

from corr_vars.definitions.typing import is_time_anchor_column, TimeAnchorColumn
from polars._typing import ClosedInterval
from polars.expr.whenthen import Then, ChainedThen
from typing import TypeAlias

# Literal (non-column) values that may be passed as a time bound.
# NOTE(review): `int` is part of this alias, but TimeAnchor.__init__ only
# special-cases `datetime` and `date` literals — confirm whether ints are
# actually meant to be accepted.
DateLiteral: TypeAlias = datetime | date | int


class TimeAnchor:
    """A single point in time, expressed as a polars expression.

    An anchor is built from one of three kinds of input:

    * a ``datetime`` / ``date`` literal, wrapped with ``pl.lit``;
    * a column name (``str``), referenced with ``pl.col``;
    * a ``(column, delta)`` pair: the column shifted with
      ``dt.offset_by(delta)``. A leading ``"+"`` is stripped from the delta
      before it is stored.

    Attributes:
        column: Referenced column name, or ``None`` for a literal anchor.
        delta: Offset string applied to the column, or ``None`` when no
            offset is involved.
        expr: The resulting expression, always aliased to ``"time"``.
    """

    column: str | None
    delta: str | None
    expr: pl.Expr

    def __init__(self, bound: TimeAnchorColumn | DateLiteral) -> None:
        """Build the anchor from a column reference or a date literal.

        Raises:
            TypeError: If ``bound`` is neither accepted by
                ``is_time_anchor_column`` nor a ``datetime`` / ``date``.
        """
        if not (is_time_anchor_column(bound) or isinstance(bound, (datetime, date))):
            raise TypeError(
                "Expected TimeAnchorColumn to be a string or a tuple of two strings or a polars expression."
            )

        if isinstance(bound, (datetime, date)):
            # Literal point in time: no column, no offset.
            column, offset, base = None, None, pl.lit(bound)
        elif isinstance(bound, str):
            # Plain column reference.
            column, offset, base = bound, None, pl.col(bound)
        else:
            # (column, delta) pair: shift the column by the given offset.
            column, raw_offset = bound
            offset = raw_offset[1:] if raw_offset[0] == "+" else raw_offset
            base = pl.col(column).dt.offset_by(offset)

        self.column = column
        self.delta = offset
        self.expr = base.alias("time")

    def __repr__(self) -> str:
        """Return a compact, human-readable description of the anchor."""
        if self.column is not None and self.delta is not None:
            # Relative anchor: render as "<column> <sign> <magnitude>".
            negative = self.delta[0] == "-"
            sign = "-" if negative else "+"
            magnitude = self.delta[1:] if negative else self.delta
            return f"{self.column} {sign} {magnitude}"

        if self.column:
            return self.column
        if self.expr.meta.is_literal():
            return f"Literal: {pl.select(self.expr).item()}"
        return ", ".join(self.expr.meta.root_names())

    def __eq__(self, other: object) -> bool:
        """Anchors are equal when column, delta and expression all match."""
        if not isinstance(other, TimeAnchor):
            return False
        if self.column != other.column or self.delta != other.delta:
            return False
        return self.expr.meta.eq(other.expr)

    @property
    def is_relative(self) -> bool:
        """``True`` when the anchor is a column shifted by an offset."""
        return not (self.column is None or self.delta is None)

    @property
    def is_literal(self) -> bool:
        """``True`` when the anchor has neither a column nor an offset."""
        return self.delta is None and self.column is None
class TimeWindow:
    """A time interval delimited by two :class:`TimeAnchor` bounds.

    Attributes:
        tmin: Lower bound of the window.
        tmax: Upper bound of the window.
        user_specified_tmin: Flag recorded for the lower bound (taken from
            the ``user_specified`` constructor argument).
        user_specified_tmax: Flag recorded for the upper bound.
    """

    tmin: TimeAnchor
    tmax: TimeAnchor
    user_specified_tmin: bool
    user_specified_tmax: bool

    def __init__(
        self,
        tmin: TimeAnchor | TimeAnchorColumn | DateLiteral | None = None,
        tmax: TimeAnchor | TimeAnchorColumn | DateLiteral | None = None,
        user_specified: bool | tuple[bool, bool] = False,
    ) -> None:
        """Create a window from two bounds.

        Args:
            tmin: Lower bound. A ``TimeAnchor`` is used as is; anything else
                is wrapped in one. A falsy value (``None``, but also e.g. an
                empty string) falls back to ``datetime.min``.
            tmax: Upper bound, handled like ``tmin``. A falsy value falls
                back to the current time at a fixed UTC+1 offset, with the
                timezone information stripped.
            user_specified: Either one flag applied to both bounds, or a
                ``(tmin_flag, tmax_flag)`` tuple.
        """
        self.tmin = (
            tmin
            if isinstance(tmin, TimeAnchor)
            else TimeAnchor(tmin or datetime.min.replace(tzinfo=None))
        )
        self.tmax = (
            tmax
            if isinstance(tmax, TimeAnchor)
            else TimeAnchor(
                tmax
                or datetime.now(tz=timezone(timedelta(hours=1))).replace(tzinfo=None)
            )
        )

        self.user_specified_tmin = (
            user_specified[0] if isinstance(user_specified, tuple) else user_specified
        )
        self.user_specified_tmax = (
            user_specified[1] if isinstance(user_specified, tuple) else user_specified
        )

    def __repr__(self) -> str:
        return f"TimeWindow(tmin={self.tmin}, tmax={self.tmax})"

    def __eq__(self, other: object) -> bool:
        """Windows are equal when both bounds are equal.

        The ``user_specified_*`` flags are not part of the comparison.
        """
        if not isinstance(other, TimeWindow):
            return False
        return self.tmin == other.tmin and self.tmax == other.tmax

    @property
    def user_specified(self) -> bool:
        """``True`` if at least one bound is flagged as user specified."""
        return self.user_specified_tmin or self.user_specified_tmax

    ####################
    #   Time Anchors   #
    ####################
    # The getters below take no extra parameters: a property getter is only
    # ever invoked with the instance, so an ``alias`` argument could never be
    # supplied through attribute access. The alias names are therefore fixed.

    @property
    def tmin_expr(self) -> pl.Expr:
        """Lower-bound expression, aliased to ``"tmin"``."""
        return self.tmin.expr.alias("tmin")

    @property
    def tmax_expr(self) -> pl.Expr:
        """Upper-bound expression, aliased to ``"tmax"``."""
        return self.tmax.expr.alias("tmax")

    ####################
    #  Time Concepts   #
    ####################

    # TODO: Use or discard
    @property
    def tmid_expr(self) -> pl.Expr:
        """``tmin`` plus half of the duration, aliased to ``"tmid"``."""
        return self.tmin_expr.add(self.duration_expr.truediv(2)).alias("tmid")

    # TODO: Use or discard
    @property
    def duration_expr(self) -> pl.Expr:
        """``tmax - tmin``, aliased to ``"duration"``."""
        return (self.tmax_expr - self.tmin_expr).alias("duration")

    @property
    def is_relative(self) -> bool:
        """``True`` if at least one bound is a relative anchor."""
        return self.tmin.is_relative or self.tmax.is_relative

    @property
    def is_literal(self) -> bool:
        """``True`` if both bounds are literal anchors."""
        return self.tmin.is_literal and self.tmax.is_literal

    ####################
    #     Validity     #
    ####################

    @property
    def any_timebound_in_year_9999_expr(self) -> pl.Expr:
        """True where either bound falls in the year 9999.

        Null bounds count as ``False`` here.
        NOTE(review): year 9999 is presumably a placeholder for open-ended
        dates — confirm against the data source.
        """
        return pl.any_horizontal(
            self.tmin_expr.dt.year().eq(pl.lit(9999)).fill_null(False),
            self.tmax_expr.dt.year().eq(pl.lit(9999)).fill_null(False),
        )

    @property
    def any_timebound_null_expr(self) -> pl.Expr:
        """True where either bound is null."""
        return pl.any_horizontal(
            self.tmin_expr.is_null(),
            self.tmax_expr.is_null(),
        )

    @property
    def tmin_after_tmax_expr(self) -> pl.Expr:
        """True where the lower bound is strictly after the upper bound."""
        return self.tmin_expr > self.tmax_expr

    @property
    def any_timebound_invalid_expr(self) -> pl.Expr:
        """True where any of the three validity checks above fires."""
        return pl.any_horizontal(
            self.any_timebound_null_expr,
            self.any_timebound_in_year_9999_expr,
            self.tmin_after_tmax_expr,
        )
class WindowRelation(Flag):
    """Position of a record interval relative to a window ``[tmin, tmax]``.

    The four elementary members are bit flags and can be combined with
    ``|``; the ``ANY_*`` members are predefined combinations of them.
    """

    # Elementary relations.
    CONTAINED = auto()  # record lies entirely within [tmin, tmax]
    OVERLAPS_LEFT = auto()  # record begins before tmin and finishes inside
    OVERLAPS_RIGHT = auto()  # record begins inside and finishes after tmax
    COVERS_WINDOW = auto()  # record extends beyond both boundaries

    # Frequently used combinations of the elementary relations.
    ANY_LEFT_OVERLAP = CONTAINED | OVERLAPS_LEFT | COVERS_WINDOW
    ANY_RIGHT_OVERLAP = CONTAINED | OVERLAPS_RIGHT | COVERS_WINDOW
    ANY_OVERLAP = ANY_LEFT_OVERLAP | ANY_RIGHT_OVERLAP
class WindowOverlap:
    """Static helpers building polars expressions that relate records to a window."""

    @staticmethod
    def contains_expr(
        record: TimeAnchor, window: TimeWindow, closed: ClosedInterval = "both"
    ) -> pl.Expr:
        """Return an expression testing whether ``record`` lies inside ``window``.

        ``closed`` is forwarded unchanged to ``Expr.is_between``.
        """
        lower = window.tmin_expr
        upper = window.tmax_expr
        return record.expr.is_between(lower, upper, closed=closed)

    @staticmethod
    def relation_expr(
        record: TimeWindow,
        window: TimeWindow,
    ) -> dict[WindowRelation, pl.Expr]:
        """Map each elementary relation to the condition that defines it.

        The record interval is ``[record.tmin, record.tmax]`` and the window
        is ``[window.tmin, window.tmax]``.
        """
        rec_start, rec_end = record.tmin_expr, record.tmax_expr
        win_start, win_end = window.tmin_expr, window.tmax_expr

        contained = (rec_start >= win_start) & (rec_end <= win_end)
        overlaps_left = (
            (rec_start < win_start) & (rec_end >= win_start) & (rec_end <= win_end)
        )
        overlaps_right = (
            (rec_start >= win_start) & (rec_start <= win_end) & (rec_end > win_end)
        )
        covers_window = (rec_start < win_start) & (rec_end > win_end)

        return {
            WindowRelation.CONTAINED: contained,
            WindowRelation.OVERLAPS_LEFT: overlaps_left,
            WindowRelation.OVERLAPS_RIGHT: overlaps_right,
            WindowRelation.COVERS_WINDOW: covers_window,
        }

    @staticmethod
    def overlap_expr(
        record: TimeWindow,
        window: TimeWindow,
        relation: WindowRelation,
    ) -> pl.Expr:
        """OR together the conditions of every flag set in ``relation``.

        Raises:
            ValueError: If ``relation`` selects none of the elementary flags.
        """
        conditions = WindowOverlap.relation_expr(record, window)
        selected = [cond for flag, cond in conditions.items() if relation & flag]
        if not selected:
            raise ValueError(f"Invalid relation: {relation}")

        # Fold left to right so the expression tree is ((a | b) | c) | d.
        combined = selected[0]
        for cond in selected[1:]:
            combined = combined | cond
        return combined

    # TODO: Use or discard
    @staticmethod
    def classify_expr(
        record: TimeWindow,
        window: TimeWindow,
        *,
        default: WindowRelation | None = None,
    ) -> pl.Expr:
        """Return a when/then chain yielding the integer value of the relation.

        Branches are added in the order CONTAINED, OVERLAPS_LEFT,
        OVERLAPS_RIGHT, COVERS_WINDOW. The ``otherwise`` branch yields
        ``default.value`` when a default is given, else a null literal.
        """
        conditions = WindowOverlap.relation_expr(record, window)

        chain: Then | ChainedThen = pl.when(
            conditions[WindowRelation.CONTAINED]
        ).then(pl.lit(WindowRelation.CONTAINED.value))
        remaining = (
            WindowRelation.OVERLAPS_LEFT,
            WindowRelation.OVERLAPS_RIGHT,
            WindowRelation.COVERS_WINDOW,
        )
        for relation in remaining:
            chain = chain.when(conditions[relation]).then(pl.lit(relation.value))

        fallback = pl.lit(None) if default is None else pl.lit(default.value)
        return chain.otherwise(fallback)
def compatible_with_time_anchor(
    df: pl.DataFrame | pl.LazyFrame, time_anchor: TimeAnchor
) -> bool:
    """Tell whether ``time_anchor`` can be evaluated against ``df``.

    Literal anchors are always compatible. Otherwise the anchor's column
    must be one of the frame's columns; a missing column name is looked up
    as the empty string.
    """
    if time_anchor.is_literal:
        return True

    if isinstance(df, pl.DataFrame):
        available = df.columns
    else:
        # Lazy frames expose their column names through the collected schema.
        available = df.collect_schema().names()
    return (time_anchor.column or "") in available
def compatible_with_time_window(
    df: pl.DataFrame | pl.LazyFrame, time_window: TimeWindow
) -> bool:
    """Tell whether both bounds of ``time_window`` are compatible with ``df``.

    The lower bound is checked first; the upper bound is only checked when
    the lower bound is compatible.
    """
    bounds = (time_window.tmin, time_window.tmax)
    return all(compatible_with_time_anchor(df, bound) for bound in bounds)