Source code for corr_vars.sources.cub_hdp_dummy.extract

from __future__ import annotations

from datetime import datetime
from typing import Literal, TYPE_CHECKING

import numpy as np
import polars as pl
from scipy.stats import poisson as scipy_poisson
from scipy.stats import truncnorm as scipy_truncnorm
from scipy.stats._distn_infrastructure import rv_continuous_frozen

from corr_vars import logger, __
from corr_vars.sources.cub_hdp.extract import VariableLoader as CubHDPVariableLoader
from corr_vars.sources.cub_hdp_dummy.random import SEED
from corr_vars.core.variable import Variable as BaseVariable
from collections.abc import Callable
from corr_vars.definitions import ObsLevel

if TYPE_CHECKING:
    from corr_vars.core.cohort import Cohort
    from corr_vars.definitions import TimeAnchorColumn, VariableCallable


def VariableLoader(*args, **kwargs) -> BaseVariable:
    """Factory function to create the correct variable type based on the arguments.

    Dispatches on the optional ``type`` keyword argument:

    - ``type="dummy_dynamic"``: ``type`` is removed from ``kwargs`` and the
      remaining arguments are forwarded to ``DummyDynamic.from_time_window``.
    - anything else (including a missing ``type``): all arguments are forwarded
      unchanged to the regular CubHDP ``VariableLoader``.

    Args:
        *args: Positional arguments forwarded to the selected constructor.
        **kwargs: Keyword arguments forwarded to the selected constructor.

    Returns:
        The variable instance built by the selected constructor.
    """
    # ``dict.get`` never raises KeyError, so no exception handling is needed:
    # a missing ``type`` simply falls through to the CubHDP loader below.
    if kwargs.get("type") == "dummy_dynamic":
        # ``DummyDynamic.__init__`` has no ``type`` parameter, so the
        # dispatch key must not be forwarded.
        kwargs.pop("type")
        return DummyDynamic.from_time_window(*args, **kwargs)

    return CubHDPVariableLoader(*args, **kwargs)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _fit_truncnorm(
    mean: float,
    q1: float,
    q3: float,
    value_min: float,
    value_max: float,
) -> rv_continuous_frozen:
    """
    Fit a truncated normal distribution from summary statistics.

    Sigma is estimated from the IQR (robust to outliers):
        sigma ≈ IQR / 1.35
    Falls back to range / 6 (i.e. ±3-sigma covers the full range) if IQR is 0.
    The distribution is then truncated to [value_min, value_max].
    """
    sigma = (q3 - q1) / 1.35
    if sigma <= 0:
        sigma = max((value_max - value_min) / 6, 1e-6)

    a = (value_min - mean) / sigma  # lower truncation in normalised units
    b = (value_max - mean) / sigma  # upper truncation in normalised units
    return scipy_truncnorm(
        a, b, loc=mean, scale=sigma
    )  # pyright: ignore[reportReturnType]


def _sample_zero_truncated_poisson(
    rng: np.random.Generator,
    lambdas: np.ndarray,
) -> np.ndarray:
    """
    Vectorised zero-truncated Poisson sampler via the inverse-CDF method.

    Standard Poisson puts some probability mass at 0 (P(X=0) = e^(-λ)).
    Zero-truncated Poisson conditions on X > 0, redistributing that mass
    across the positive integers while preserving the shape of the tail.

    Method — map a Uniform draw into the non-zero region of the Poisson CDF:
        U  ~ Uniform(0, 1)
        U' = e^(-λ) + U × (1 - e^(-λ))     # shift into (F(0), 1]
        k  = Poisson_ppf(U', λ)              # invert CDF → guaranteed k ≥ 1

    Fully vectorised via ``scipy.stats.poisson.ppf``; handles a different λ
    per patient in a single call with no rejection-sampling loop.
    """
    p0 = np.exp(-lambdas)  # P(X = 0) = e^(-λ) per patient
    u = rng.random(size=len(lambdas))
    u_shifted = p0 + u * (1.0 - p0)  # map into (p0, 1] → skip zero
    return scipy_poisson.ppf(u_shifted, mu=lambdas).astype(int)


# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------


class DummyDynamic(BaseVariable):
    """
    Dummy dynamic variable that generates synthetic time-series data.

    Supports four value types: numeric, boolean, string/categorical, datetime.

    Patient selection and measurement frequency are controlled by two
    independent parameters:

    - ``missingness``: fraction of patients who have **no** measurements at
      all (sampled once, randomly, before density is applied).
    - ``density``: expected measurements per day for patients **with** data.
      Uses a zero-truncated Poisson so selected patients always get ≥ 1 row,
      and the count distribution is not distorted by the zero-inflation that
      a plain Poisson would introduce.

    Args:
        var_name: Name of the variable.
        tmin: Minimum time bound column.
        tmax: Maximum time bound column.
        py: Optional Python function to apply to the variable post-extraction.
        data_type: One of ``"numeric"``, ``"boolean"``, ``"string"``,
            ``"datetime"``.
        mean: Mean of the distribution (numeric / datetime in epoch seconds).
        value_min: Minimum value (numeric / datetime in epoch seconds).
        value_max: Maximum value (numeric / datetime in epoch seconds).
        q1: First quartile (numeric / datetime in epoch seconds).
        q3: Third quartile (numeric / datetime in epoch seconds).
        decimals: Round numeric values to this many decimal places (optional).
            Ignored for every other ``data_type``.
        missingness: Fraction of patients with **no** measurements (0.0–1.0).
            0.0 → all patients have data; 1.0 → no patients have data.
            Applied as a Bernoulli draw per patient before density sampling.
            Default: 0.0.
        density: Expected measurements **per patient per day** for patients
            who have data. Each selected patient's ZTP
            lambda = density × patient_duration_days, guaranteeing ≥ 1
            measurement. Examples: 1.0 → daily, 0.143 → weekly,
            24.0 → hourly. Defaults to 1.0 when ``None``.
        true_freq: Frequency of ``True`` values for ``data_type="boolean"``
            (0–1). Defaults to 0.5.
        categories: For ``data_type="string"``:

            - ``list[str]``: categories sampled with uniform probability.
            - ``dict[str, float]``: ``{category: weight}`` — weights may sum
              to 1.0 *or* 100; they are normalised automatically.
        observation_type: ``"snapshot"`` — each row has a single
            ``recordtime``. ``"interval"`` — each row also has a
            ``recordtime_end``, derived by adding a random duration to
            ``recordtime``.
        interval_min_s: Minimum interval duration in seconds (default: 60).
        interval_max_s: Maximum interval duration in seconds
            (default: 86 400).
    """

    def __init__(
        self,
        # core.variable.Variable — DO NOT CHANGE
        var_name: str,
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        # ------------------------------------------------------------------ #
        # Data type
        data_type: Literal["numeric", "boolean", "string", "datetime"] | None = None,
        # Numeric / datetime distribution stats
        mean: float | None = None,
        value_min: float | None = None,
        value_max: float | None = None,
        q1: float | None = None,
        q3: float | None = None,
        decimals: int | None = None,
        # Missingness & measurement density — shared across all data types
        missingness: float = 0.0,
        density: float | None = None,
        # Boolean
        true_freq: float | None = None,
        # String / categorical
        categories: list[str] | dict[str, float] | None = None,
        # Observation type & interval params
        observation_type: Literal["snapshot", "interval"] = "snapshot",
        interval_min_s: int = 60,
        interval_max_s: int = 24 * 60 * 60,
    ):
        super().__init__(
            var_name=var_name,
            dynamic=True,
            requires=[],
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=True,
            cleaning=None,
        )
        self.data_type = data_type

        # Numeric / datetime
        self.mean = mean
        self.value_min = value_min
        self.value_max = value_max
        self.q1 = q1
        self.q3 = q3
        self.decimals = decimals

        # Missingness & density
        self.missingness = missingness
        self.density = density if density is not None else 1.0

        # Boolean
        self.true_freq = true_freq if true_freq is not None else 0.5

        # String
        self.categories = categories

        # Observation type
        self.observation_type = observation_type
        self.interval_min_s = interval_min_s
        self.interval_max_s = interval_max_s

    # ---------------------------------------------------------------------- #
    # Protocol method
    # ---------------------------------------------------------------------- #

    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """Generate the synthetic data for ``cohort`` and post-process it.

        The generated frame is stored on ``self.data``; afterwards the
        variable function, the time filter and the cleaning step are run in
        that order, and the resulting ``self.data`` is returned.
        """
        # _add_time_window is called inside _extract_dummy_data so that
        # per-patient tmin/tmax are available during row generation.
        self.data = self._extract_dummy_data(cohort)
        self._call_var_function(cohort)
        self._timefilter()
        self._apply_cleaning()
        return self.data

    # ---------------------------------------------------------------------- #
    # Core generation logic
    # ---------------------------------------------------------------------- #

    def _extract_dummy_data(self, cohort: Cohort) -> pl.DataFrame:
        """Build the long-format synthetic measurement table for ``cohort``.

        Returns a frame with the cohort's primary key plus ``recordtime``,
        ``recordtime_end``, ``value``, ``value_unit``, ``case_tmin`` and
        ``case_tmax``. As a side effect ``self.data`` is overwritten with the
        per-patient time-window frame while the rows are generated.
        """
        rng = np.random.default_rng(seed=SEED)
        obs_level: ObsLevel = cohort.obs_level
        pk = obs_level.primary_key
        unique_ids = cohort.obs[pk].unique(maintain_order=True).to_numpy()

        # --- Step 1: Resolve per-patient time windows --------------------- #
        # _add_time_window requires obs_level.primary_key to be present in
        # self.data; it joins case_tmin/case_tmax onto the dataframe in place.
        self.data = pl.DataFrame({pk: unique_ids})
        self._add_time_window(cohort)  # → adds "case_tmin", "case_tmax"

        # --- Step 2: Apply missingness ------------------------------------ #
        # Bernoulli draw per patient: True → patient has measurements.
        # This is independent of duration, purely a prevalence parameter.
        has_data = rng.random(size=len(unique_ids)) >= self.missingness

        if not has_data.any():
            logger.warning(
                __(
                    "[{var_name}] No patients selected after missingness filter "
                    "(missingness={missingness}). Returning empty DataFrame.",
                    var_name=self.var_name,
                    missingness=self.missingness,
                )
            )
            return self._empty_dataframe(obs_level)

        selected_ids = unique_ids[has_data]
        # NOTE(review): relies on _add_time_window keeping one row per id in
        # the original order — confirm against the base class.
        selected_data = self.data.filter(pl.Series(has_data))

        # --- Step 3: Per-patient zero-truncated Poisson count ------------- #
        # lambda_i = density [meas/day] × duration_i [days]
        # ZTP guarantees ≥ 1 measurement for every selected patient, while
        # preserving the conditional count distribution (no zero-inflation).
        durations_days = selected_data.select(
            pl.col("case_tmax")
            .sub(pl.col("case_tmin"))
            .dt.total_seconds()
            .truediv(86_400)
            .alias("dur_day")
        )["dur_day"].to_numpy()
        durations_days = np.maximum(durations_days, 0.0)  # guard: tmax >= tmin

        lambdas = self.density * durations_days  # per-patient ZTP lambda

        # A zero-length window (or density 0) gives lambda = 0, where the
        # inverse-CDF sampler has nothing to invert. The zero-truncated
        # Poisson degenerates to exactly one measurement there, so those
        # patients get a count of 1 directly and only the rest are sampled.
        zero_window = lambdas <= 0.0
        counts_per_patient = np.ones(len(lambdas), dtype=int)
        if not zero_window.all():
            counts_per_patient[~zero_window] = _sample_zero_truncated_poisson(
                rng, lambdas[~zero_window]
            )
        total_rows = int(counts_per_patient.sum())

        # --- Step 4: Expand arrays to one-row-per-measurement ------------- #
        repeated_ids = np.repeat(selected_ids, counts_per_patient)

        # Extract per-patient bounds as numpy arrays for vectorised ops
        tmins = selected_data["case_tmin"].dt.epoch("us").to_numpy()
        tmaxs = selected_data["case_tmax"].dt.epoch("us").to_numpy()

        # Repeat the bounds so each generated row carries its patient's window
        repeated_tmins = np.repeat(tmins, counts_per_patient)
        repeated_tmaxs = np.repeat(tmaxs, counts_per_patient)

        # --- Step 5: Generate recordtimes within per-patient [tmin, tmax] - #
        recordtime = rng.uniform(low=repeated_tmins, high=repeated_tmaxs).astype(
            "datetime64[us]"
        )

        # --- Step 6: Generate values -------------------------------------- #
        value_gen = self._make_value_generator(rng)
        values = value_gen(total_rows)

        # --- Step 7: Interval end time (optional) ------------------------- #
        if self.observation_type == "interval":
            durations_s = rng.integers(
                self.interval_min_s, self.interval_max_s + 1, size=total_rows
            )
            recordtime_end = recordtime + durations_s.astype("timedelta64[s]")
        else:
            recordtime_end = np.full(total_rows, None)

        # --- Step 8: Assemble final dataframe ----------------------------- #
        # case_tmin / case_tmax carried forward for _timefilter downstream.
        df = pl.DataFrame(
            {
                pk: repeated_ids,
                "recordtime": recordtime,
                "recordtime_end": recordtime_end,
                "value": values,
                "value_unit": np.full(total_rows, None),
                "case_tmin": repeated_tmins,
                "case_tmax": repeated_tmaxs,
            },
            schema_overrides={
                pk: pl.Utf8,
                "recordtime": pl.Datetime,
                "recordtime_end": pl.Datetime,
                "value_unit": pl.Utf8,
                "case_tmin": pl.Datetime,
                "case_tmax": pl.Datetime,
            },
        )

        # --- Step 9: Cleanup ---------------------------------------------- #
        df = df.with_columns(pl.selectors.datetime().dt.truncate("1s"))

        # `decimals` is documented as numeric-only; rounding is skipped for
        # boolean / string / datetime value columns, where it is not defined.
        if self.decimals is not None and self.data_type == "numeric":
            df = df.with_columns(pl.col("value").round(self.decimals))

        return df

    # ---------------------------------------------------------------------- #
    # Value generators — one per data type
    # ---------------------------------------------------------------------- #

    def _make_value_generator(
        self, rng: np.random.Generator
    ) -> Callable[[int], np.ndarray]:
        """Return a closure that generates *n* synthetic values.

        Raises:
            ValueError: If ``data_type`` is unsupported, or a parameter
                required by the selected ``data_type`` is missing or invalid.
        """
        if self.data_type == "numeric":
            dist = self._fit_value_distribution()

            def generate(n: int) -> np.ndarray:
                return dist.rvs(size=n, random_state=rng)

        elif self.data_type == "datetime":
            # Epoch-seconds treated identically to numeric; converted back
            # to datetime objects afterwards.
            dist = self._fit_value_distribution()

            def generate(n: int) -> np.ndarray:
                epoch_seconds = dist.rvs(size=n, random_state=rng)
                # NOTE(review): datetime.fromtimestamp without a tz argument
                # converts to the machine's local timezone, so the generated
                # values depend on where the code runs.
                return np.array(
                    [datetime.fromtimestamp(float(s)) for s in epoch_seconds],
                    dtype="object",
                )

        elif self.data_type == "boolean":

            def generate(n: int) -> np.ndarray:
                return rng.random(size=n) < self.true_freq

        elif self.data_type == "string":
            cats, weights = self._parse_categories()

            def generate(n: int) -> np.ndarray:
                return rng.choice(cats, size=n, p=weights)

        else:
            raise ValueError(
                f"[{self.var_name}] Unsupported data_type: {self.data_type!r}. "
                "Choose one of: 'numeric', 'boolean', 'string', 'datetime'."
            )

        return generate

    # ---------------------------------------------------------------------- #
    # Helpers
    # ---------------------------------------------------------------------- #

    def _fit_value_distribution(self) -> rv_continuous_frozen:
        """Validate the summary statistics and fit the truncated normal.

        Shared by the numeric and datetime generators, which need the same
        five parameters.

        Raises:
            ValueError: If any of ``mean``, ``q1``, ``q3``, ``value_min`` or
                ``value_max`` is ``None`` (checked in that order).
        """
        stats = {
            name: getattr(self, name)
            for name in ("mean", "q1", "q3", "value_min", "value_max")
        }
        for name, value in stats.items():
            if value is None:
                raise ValueError(
                    f"[{self.var_name}] `{name}` must be set "
                    f"for data_type={self.data_type!r}."
                )
        # The dict keys match the keyword parameters of _fit_truncnorm.
        return _fit_truncnorm(**stats)

    def _parse_categories(self) -> tuple[list[str], np.ndarray]:
        """Parse the ``categories`` argument into (labels, normalised weights).

        Raises:
            ValueError: If ``categories`` is ``None`` or empty, or if the
                weights are negative, non-finite, or sum to zero.
        """
        if self.categories is None:
            raise ValueError(
                f"[{self.var_name}] `categories` must be set for data_type='string'."
            )

        if isinstance(self.categories, list):
            cats = self.categories
            raw = np.ones(len(cats), dtype=float)  # uniform
        else:
            cats = list(self.categories.keys())
            raw = np.array(list(self.categories.values()), dtype=float)

        # Fail early with a clear message instead of letting rng.choice
        # reject an empty label list or NaN / negative probabilities later.
        if not cats:
            raise ValueError(f"[{self.var_name}] `categories` must not be empty.")
        total = raw.sum()
        if (raw < 0).any() or not np.isfinite(total) or total <= 0:
            raise ValueError(
                f"[{self.var_name}] `categories` weights must be non-negative "
                "and sum to a positive, finite value."
            )

        weights = raw / total  # normalise — handles both 0–1 and 0–100
        return cats, weights

    def _empty_dataframe(self, obs_level: ObsLevel) -> pl.DataFrame:
        """Return a correctly-typed empty DataFrame when no rows are generated."""
        return pl.DataFrame(
            {
                obs_level.primary_key: pl.Series([], dtype=pl.Utf8),
                "recordtime": pl.Series([], dtype=pl.Datetime),
                "recordtime_end": pl.Series([], dtype=pl.Datetime),
                "value": pl.Series([], dtype=pl.Utf8),
                "value_unit": pl.Series([], dtype=pl.Utf8),
                "case_tmin": pl.Series([], dtype=pl.Datetime),
                "case_tmax": pl.Series([], dtype=pl.Datetime),
            }
        )