from __future__ import annotations
from datetime import datetime
from typing import Literal, TYPE_CHECKING
import numpy as np
import polars as pl
from scipy.stats import poisson as scipy_poisson
from scipy.stats import truncnorm as scipy_truncnorm
from scipy.stats._distn_infrastructure import rv_continuous_frozen
from corr_vars import logger, __
from corr_vars.sources.cub_hdp.extract import VariableLoader as CubHDPVariableLoader
from corr_vars.sources.cub_hdp_dummy.random import SEED
from corr_vars.core.variable import Variable as BaseVariable
from collections.abc import Callable
from corr_vars.definitions import ObsLevel
if TYPE_CHECKING:
from corr_vars.core.cohort import Cohort
from corr_vars.definitions import TimeAnchorColumn, VariableCallable
def VariableLoader(*args, **kwargs) -> BaseVariable:
    """Create the variable object matching the ``type`` keyword argument.

    ``type="dummy_dynamic"`` is handled by this module: the ``type`` key is
    removed from ``kwargs`` and the remaining arguments are forwarded to
    ``DummyDynamic.from_time_window``. Any other ``type`` value — including a
    missing one — is delegated unchanged to the CubHDP ``VariableLoader``.

    Returns:
        The variable instance produced by the selected constructor.
    """
    # ``dict.get`` never raises ``KeyError``: a missing ``type`` yields ``None``
    # and falls through to the CubHDP loader below.
    var_type = kwargs.get("type")
    if var_type == "dummy_dynamic":
        kwargs.pop("type")  # guaranteed present: var_type came from it
        return DummyDynamic.from_time_window(*args, **kwargs)
    return CubHDPVariableLoader(*args, **kwargs)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fit_truncnorm(
    mean: float,
    q1: float,
    q3: float,
    value_min: float,
    value_max: float,
) -> rv_continuous_frozen:
    """
    Fit a truncated normal distribution from summary statistics.

    Sigma is estimated from the IQR (robust to outliers):
        sigma ≈ IQR / 1.35
    (for a normal distribution the IQR is ≈ 1.349 sigma).
    Falls back to range / 6 (i.e. ±3-sigma covers the full range) if the IQR
    is not positive, with a floor of 1e-6 so sigma is never zero.
    The normal is centred on ``mean`` (note: after asymmetric truncation the
    realised mean can differ from ``mean``) and truncated to
    [value_min, value_max].

    Args:
        mean: Location of the underlying, untruncated normal.
        q1: First quartile.
        q3: Third quartile.
        value_min: Lower truncation bound.
        value_max: Upper truncation bound.

    Returns:
        A frozen ``scipy.stats.truncnorm`` distribution.

    Raises:
        ValueError: If ``value_min`` is not strictly less than ``value_max``
            (this also rejects NaN bounds). ``truncnorm`` needs a non-empty
            interval; otherwise sampling fails later with an opaque error.
    """
    if not value_min < value_max:
        raise ValueError(
            f"value_min ({value_min}) must be strictly less than "
            f"value_max ({value_max}) to fit a truncated normal."
        )
    sigma = (q3 - q1) / 1.35
    if sigma <= 0:
        sigma = max((value_max - value_min) / 6, 1e-6)
    a = (value_min - mean) / sigma  # lower truncation in normalised units
    b = (value_max - mean) / sigma  # upper truncation in normalised units
    return scipy_truncnorm(a, b, loc=mean, scale=sigma)  # pyright: ignore[reportReturnType]
def _sample_zero_truncated_poisson(
    rng: np.random.Generator,
    lambdas: np.ndarray,
) -> np.ndarray:
    """
    Vectorised zero-truncated Poisson sampler via the inverse-CDF method.

    Standard Poisson puts some probability mass at 0 (P(X=0) = e^(-λ)).
    Zero-truncated Poisson conditions on X > 0, redistributing that mass
    across the positive integers while preserving the shape of the tail.

    Method — map a Uniform draw into the non-zero region of the Poisson CDF:
        U  ~ Uniform(0, 1)
        U' = e^(-λ) + U × (1 - e^(-λ))   # shift into [F(0), 1)
        k  = Poisson_ppf(U', λ)          # invert CDF → k ≥ 1

    Fully vectorised via ``scipy.stats.poisson.ppf``; handles a different λ
    per patient in a single call with no rejection-sampling loop.

    Edge cases: as λ → 0 the zero-truncated Poisson collapses onto k = 1.
    For λ = 0, or λ so small that e^(-λ) rounds to 1.0, U' equals 1.0 and
    ``ppf`` returns +inf. Such non-finite results are mapped to 1. Results
    are also clamped to ≥ 1 to absorb floating-point round-off at the F(0)
    boundary, where ``ppf`` could otherwise return 0.

    Args:
        rng: Random generator; one ``rng.random`` call of size
            ``len(lambdas)`` is made.
        lambdas: One finite, non-negative Poisson rate per element.

    Returns:
        Integer array of counts, same length as ``lambdas``, every entry ≥ 1.

    Raises:
        ValueError: If any rate is negative or not finite (e.g. NaN).
    """
    lambdas = np.asarray(lambdas, dtype=float)
    if np.any(~np.isfinite(lambdas) | (lambdas < 0)):
        raise ValueError("Zero-truncated Poisson rates must be finite and >= 0.")
    p0 = np.exp(-lambdas)  # P(X = 0) = e^(-λ) per patient
    nonzero_mass = -np.expm1(-lambdas)  # 1 - e^(-λ), accurate for small λ
    u = rng.random(size=len(lambdas))
    u_shifted = p0 + u * nonzero_mass  # map into [p0, 1) → skip zero
    counts = scipy_poisson.ppf(u_shifted, mu=lambdas)
    # ppf(1.0) is +inf (and NaN if a scipy version rejects mu=0); the λ → 0
    # limit of the zero-truncated Poisson is exactly 1.
    counts = np.where(np.isfinite(counts), counts, 1.0)
    return np.maximum(counts, 1.0).astype(int)
# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------
class DummyDynamic(BaseVariable):
    """
    Dummy dynamic variable that generates synthetic time-series data.

    Supports four value types: numeric, boolean, string/categorical, datetime.

    Patient selection and measurement frequency are controlled by two
    independent parameters:

    - ``missingness``: fraction of patients who have **no** measurements at all
      (sampled once, randomly, before density is applied).
    - ``density``: expected measurements per day for patients **with** data.
      Uses a zero-truncated Poisson so selected patients always get ≥ 1 row,
      and the count distribution is not distorted by the zero-inflation that
      a plain Poisson would introduce.

    Every extraction builds a fresh generator seeded with ``SEED``, so the
    same configuration and cohort always yield the same rows.

    Args:
        var_name: Name of the variable.
        tmin: Minimum time bound column.
        tmax: Maximum time bound column.
        py: Optional Python function to apply to the variable post-extraction.
        data_type:
            One of ``"numeric"``, ``"boolean"``, ``"string"``, ``"datetime"``.
        mean: Centre of the normal distribution before truncation
            (numeric / datetime in epoch seconds).
        value_min: Minimum value (numeric / datetime in epoch seconds).
        value_max: Maximum value (numeric / datetime in epoch seconds).
        q1: First quartile (numeric / datetime in epoch seconds).
        q3: Third quartile (numeric / datetime in epoch seconds).
            Datetime epoch seconds are interpreted as UTC and emitted as
            naive datetimes, the same convention used for ``recordtime``.
        decimals: Round numeric values to this many decimal places
            (optional; ignored for non-numeric data types).
        missingness:
            Fraction of patients with **no** measurements (0.0–1.0).
            0.0 → all patients have data; 1.0 → no patients have data.
            Applied as a Bernoulli draw per patient before density sampling.
            Patients whose time window is missing never receive data.
            Default: 0.0.
        density:
            Expected measurements **per patient per day** for patients who
            have data. Each selected patient's ZTP lambda =
            density × patient_duration_days, guaranteeing ≥ 1 measurement.
            Examples: 1.0 → daily, 0.143 → weekly, 24.0 → hourly.
            Default: 1.0.
        true_freq:
            Frequency of ``True`` values for ``data_type="boolean"`` (0–1).
            Defaults to 0.5.
        categories:
            For ``data_type="string"``:
            - ``list[str]``: categories sampled with uniform probability.
            - ``dict[str, float]``: ``{category: weight}`` — weights may sum
              to 1.0 *or* 100; they are normalised automatically.
            Must be non-empty; weights must be non-negative with a positive sum.
        observation_type:
            ``"snapshot"`` — each row has a single ``recordtime``.
            ``"interval"`` — each row also has a ``recordtime_end``, derived
            by adding a random duration to ``recordtime``.
        interval_min_s: Minimum interval duration in seconds (default: 60).
        interval_max_s: Maximum interval duration in seconds (default: 86 400).
    """

    def __init__(
        self,
        # core.variable.Variable — DO NOT CHANGE
        var_name: str,
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        # ------------------------------------------------------------------ #
        # Data type
        data_type: Literal["numeric", "boolean", "string", "datetime"] | None = None,
        # Numeric / datetime distribution stats
        mean: float | None = None,
        value_min: float | None = None,
        value_max: float | None = None,
        q1: float | None = None,
        q3: float | None = None,
        decimals: int | None = None,
        # Missingness & measurement density — shared across all data types
        missingness: float = 0.0,
        density: float | None = None,
        # Boolean
        true_freq: float | None = None,
        # String / categorical
        categories: list[str] | dict[str, float] | None = None,
        # Observation type & interval params
        observation_type: Literal["snapshot", "interval"] = "snapshot",
        interval_min_s: int = 60,
        interval_max_s: int = 24 * 60 * 60,
    ):
        super().__init__(
            var_name=var_name,
            dynamic=True,
            requires=[],
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=True,
            cleaning=None,
        )
        self.data_type = data_type
        # Numeric / datetime
        self.mean = mean
        self.value_min = value_min
        self.value_max = value_max
        self.q1 = q1
        self.q3 = q3
        self.decimals = decimals
        # Missingness & density
        self.missingness = missingness
        self.density = density if density is not None else 1.0
        # Boolean
        self.true_freq = true_freq if true_freq is not None else 0.5
        # String
        self.categories = categories
        # Observation type
        self.observation_type = observation_type
        self.interval_min_s = interval_min_s
        self.interval_max_s = interval_max_s

    # ---------------------------------------------------------------------- #
    # Core generation logic
    # ---------------------------------------------------------------------- #

    def _extract_dummy_data(self, cohort: Cohort) -> pl.DataFrame:
        """
        Generate the synthetic long-format measurement table for ``cohort``.

        The result has one row per measurement, with columns: the primary
        key, ``recordtime``, ``recordtime_end``, ``value``, ``value_unit``,
        ``case_tmin`` and ``case_tmax``. A typed empty DataFrame is returned
        when no patient ends up with measurements.

        Side effect: ``self.data`` is replaced by the per-patient frame that
        carries the resolved time windows.
        """
        rng = np.random.default_rng(seed=SEED)
        obs_level: ObsLevel = cohort.obs_level
        pk = obs_level.primary_key
        unique_ids = cohort.obs[pk].unique(maintain_order=True).to_numpy()

        # --- Step 1: Resolve per-patient time windows --------------------- #
        # _add_time_window requires obs_level.primary_key to be present in
        # self.data; it joins case_tmin/case_tmax onto the dataframe.
        self.data = pl.DataFrame({pk: unique_ids})
        self._add_time_window(cohort)  # → adds "case_tmin", "case_tmax"
        # From here on, ids AND windows are read from this one frame, so they
        # stay aligned even if the join reorders or changes rows.
        windows = self.data

        # --- Step 2: Apply missingness ------------------------------------ #
        # Bernoulli draw per patient: True → patient has measurements.
        # This is independent of duration, purely a prevalence parameter.
        has_data = rng.random(size=windows.height) >= self.missingness
        # A patient without a resolvable window cannot have measurements in
        # it; keeping them would yield NaN Poisson rates and record times.
        has_window = (
            windows.select(
                pl.col("case_tmin").is_not_null() & pl.col("case_tmax").is_not_null()
            )
            .to_series()
            .to_numpy()
        )
        selected = has_data & has_window
        if not selected.any():
            logger.warning(
                __(
                    "[{var_name}] No patients selected after missingness filter "
                    "(missingness={missingness}). Returning empty DataFrame.",
                    var_name=self.var_name,
                    missingness=self.missingness,
                )
            )
            return self._empty_dataframe(obs_level)
        selected_data = windows.filter(pl.Series(selected))
        selected_ids = selected_data[pk].to_numpy()

        # --- Step 3: Per-patient zero-truncated Poisson count ------------- #
        # lambda_i = density [meas/day] × duration_i [days]
        # ZTP guarantees ≥ 1 measurement for every selected patient, while
        # preserving the conditional count distribution (no zero-inflation).
        durations_days = selected_data.select(
            pl.col("case_tmax")
            .sub(pl.col("case_tmin"))
            .dt.total_seconds()
            .truediv(86_400)
            .alias("dur_day")
        )["dur_day"].to_numpy()
        durations_days = np.maximum(durations_days, 0.0)  # guard: tmax >= tmin
        lambdas = self.density * durations_days  # per-patient ZTP lambda
        counts_per_patient = _sample_zero_truncated_poisson(rng, lambdas)
        total_rows = int(counts_per_patient.sum())

        # --- Step 4: Expand arrays to one-row-per-measurement ------------- #
        repeated_ids = np.repeat(selected_ids, counts_per_patient)
        # Window bounds as epoch microseconds (int64) for vectorised maths;
        # the schema overrides below turn them back into Datetime columns.
        tmins = selected_data["case_tmin"].dt.epoch("us").to_numpy()
        tmaxs = selected_data["case_tmax"].dt.epoch("us").to_numpy()
        repeated_tmins = np.repeat(tmins, counts_per_patient)
        repeated_tmaxs = np.repeat(tmaxs, counts_per_patient)

        # --- Step 5: Generate recordtimes within per-patient [tmin, tmax] - #
        recordtime = rng.uniform(low=repeated_tmins, high=repeated_tmaxs).astype(
            "datetime64[us]"
        )

        # --- Step 6: Generate values -------------------------------------- #
        value_gen = self._make_value_generator(rng)
        values = value_gen(total_rows)

        # --- Step 7: Interval end time (optional) ------------------------- #
        if self.observation_type == "interval":
            durations_s = rng.integers(
                self.interval_min_s, self.interval_max_s + 1, size=total_rows
            )
            recordtime_end = recordtime + durations_s.astype("timedelta64[s]")
        else:
            recordtime_end = np.full(total_rows, None)

        # --- Step 8: Assemble final dataframe ----------------------------- #
        # case_tmin / case_tmax carried forward for _timefilter downstream.
        df = pl.DataFrame(
            {
                pk: repeated_ids,
                "recordtime": recordtime,
                "recordtime_end": recordtime_end,
                "value": values,
                "value_unit": np.full(total_rows, None),
                "case_tmin": repeated_tmins,
                "case_tmax": repeated_tmaxs,
            },
            schema_overrides={
                pk: pl.Utf8,
                "recordtime": pl.Datetime,
                "recordtime_end": pl.Datetime,
                "value_unit": pl.Utf8,
                "case_tmin": pl.Datetime,
                "case_tmax": pl.Datetime,
            },
        )

        # --- Step 9: Cleanup ---------------------------------------------- #
        df = df.with_columns(pl.selectors.datetime().dt.truncate("1s"))
        # Rounding only makes sense for floats; on boolean/string/datetime
        # value columns it would raise.
        if self.decimals is not None and self.data_type == "numeric":
            df = df.with_columns(pl.col("value").round(self.decimals))
        return df

    # ---------------------------------------------------------------------- #
    # Value generators — one per data type
    # ---------------------------------------------------------------------- #

    def _make_value_generator(
        self, rng: np.random.Generator
    ) -> Callable[[int], np.ndarray]:
        """
        Return a closure that draws *n* synthetic values for ``data_type``.

        Every closure draws from ``rng``, continuing the extraction's seeded
        random stream.

        Raises:
            ValueError: If ``data_type`` is unsupported or a parameter it
                requires is missing or invalid.
        """
        if self.data_type == "numeric":
            dist = self._fit_value_distribution("numeric")

            def generate(n: int) -> np.ndarray:
                return dist.rvs(size=n, random_state=rng)

        elif self.data_type == "datetime":
            # Epoch seconds are sampled exactly like numeric values and then
            # converted to datetimes.
            dist = self._fit_value_distribution("datetime")

            def generate(n: int) -> np.ndarray:
                epoch_seconds = dist.rvs(size=n, random_state=rng)
                # Interpret epoch seconds as UTC (as ``recordtime`` does).
                # ``datetime.fromtimestamp`` would apply the host's local
                # timezone and make the seeded output machine-dependent.
                epoch_us = np.rint(epoch_seconds * 1_000_000).astype("int64")
                return epoch_us.astype("datetime64[us]")

        elif self.data_type == "boolean":

            def generate(n: int) -> np.ndarray:
                return rng.random(size=n) < self.true_freq

        elif self.data_type == "string":
            cats, weights = self._parse_categories()

            def generate(n: int) -> np.ndarray:
                return rng.choice(cats, size=n, p=weights)

        else:
            raise ValueError(
                f"[{self.var_name}] Unsupported data_type: {self.data_type!r}. "
                "Choose one of: 'numeric', 'boolean', 'string', 'datetime'."
            )
        return generate

    # ---------------------------------------------------------------------- #
    # Helpers
    # ---------------------------------------------------------------------- #

    def _fit_value_distribution(self, data_type: str) -> rv_continuous_frozen:
        """
        Check that all summary statistics are set, then fit the truncnorm.

        Fields are checked in the order mean, q1, q3, value_min, value_max;
        the first missing one raises ``ValueError``.
        """
        stats = {
            "mean": self.mean,
            "q1": self.q1,
            "q3": self.q3,
            "value_min": self.value_min,
            "value_max": self.value_max,
        }
        for field, stat in stats.items():
            if stat is None:
                raise ValueError(
                    f"[{self.var_name}] `{field}` must be set for data_type='{data_type}'."
                )
        return _fit_truncnorm(**stats)  # pyright: ignore[reportArgumentType]

    def _parse_categories(self) -> tuple[list[str], np.ndarray]:
        """
        Parse ``categories`` into (labels, normalised probability weights).

        Raises:
            ValueError: If ``categories`` is unset or empty, or if its
                weights are negative, non-finite, or sum to zero.
        """
        if self.categories is None:
            raise ValueError(
                f"[{self.var_name}] `categories` must be set for data_type='string'."
            )
        if len(self.categories) == 0:
            raise ValueError(f"[{self.var_name}] `categories` must not be empty.")
        if isinstance(self.categories, list):
            cats = self.categories
            weights = np.ones(len(cats)) / len(cats)  # uniform
        else:
            cats = list(self.categories.keys())
            raw = np.array(list(self.categories.values()), dtype=float)
            if not np.all(np.isfinite(raw)) or np.any(raw < 0) or raw.sum() <= 0:
                raise ValueError(
                    f"[{self.var_name}] `categories` weights must be finite, "
                    "non-negative and have a positive sum."
                )
            weights = raw / raw.sum()  # normalise — handles both 0–1 and 0–100
        return cats, weights

    def _empty_dataframe(self, obs_level: ObsLevel) -> pl.DataFrame:
        """
        Return a correctly-typed empty DataFrame when no rows are generated.

        The ``value`` dtype follows ``data_type``, so it matches the non-empty
        output; unknown or unset types fall back to Utf8.
        """
        value_dtype = {
            "numeric": pl.Float64,
            "boolean": pl.Boolean,
            "string": pl.Utf8,
            "datetime": pl.Datetime,
        }.get(self.data_type, pl.Utf8)
        return pl.DataFrame(
            {
                obs_level.primary_key: pl.Series([], dtype=pl.Utf8),
                "recordtime": pl.Series([], dtype=pl.Datetime),
                "recordtime_end": pl.Series([], dtype=pl.Datetime),
                "value": pl.Series([], dtype=value_dtype),
                "value_unit": pl.Series([], dtype=pl.Utf8),
                "case_tmin": pl.Series([], dtype=pl.Datetime),
                "case_tmax": pl.Series([], dtype=pl.Datetime),
            }
        )