Source code for corr_vars.sources.cub_hdp.extract

from __future__ import annotations

from datetime import datetime
import os
import pickle
import textwrap

import pandas as pd
import polars as pl

from corr_vars import logger, __
from corr_vars.core.variable import Variable as BaseVariable
from corr_vars.sources.var_loader import MultiSourceVariable
from corr_vars.sources.cub_hdp import helpers
from corr_vars.sources.cub_hdp.mapping.drug import DRUG_MAPPING_PATH
from corr_vars.sources.cub_hdp.mapping.lab import LAB_MAPPING_PATH
import corr_vars.utils as utils
from corr_vars.utils.frames import column_selector

from abc import ABC, abstractmethod
from corr_vars.definitions.typing import ObsLevel, VariableProtocol
from corr_vars.utils.time import TimeWindow
from typing import TYPE_CHECKING, Literal, cast

if TYPE_CHECKING:
    from corr_vars.core import Cohort
    from corr_vars.definitions import (
        TimeAnchorColumn,
        RequirementsIterable,
        CleaningDict,
        VariableCallable,
    )


def VariableLoader(*args, **kwargs) -> BaseVariable:
    """Factory function to create the correct variable type based on the arguments."""
    # Legacy naming error
    if kwargs.get("variable", None):
        kwargs["base_var"] = kwargs.pop("variable")

    try:
        var_type = kwargs.pop("type")

        # Overwrite dynamic attribute
        if var_type in ["native_dynamic", "derived_dynamic", "co6_hierarchy"]:
            kwargs["dynamic"] = True
    except KeyError:
        logger.error("CubHDP VariableLoader requires the 'type' field.")

    if var_type == "complex":
        return ComplexVariable.from_time_window(*args, **kwargs)
    elif var_type == "native_dynamic":
        return NativeDynamic.from_time_window(*args, **kwargs)
    elif var_type == "co6_hierarchy":
        return Co6HierarchyDynamic.from_time_window(*args, **kwargs)
    elif var_type == "derived_dynamic":
        return DerivedDynamic.from_time_window(*args, **kwargs)
    elif var_type == "native_static":
        return NativeStatic.from_time_window(*args, **kwargs)
    elif var_type == "derived_static":
        return DerivedStatic.from_time_window(*args, **kwargs)
    elif var_type == "medication":
        return MedicationDynamic.from_time_window(*args, **kwargs)
    elif var_type == "laboratory":
        return LaboratoryDynamic.from_time_window(*args, **kwargs)
    else:
        raise KeyError(f"Type {var_type} does not exist.")


class CubHDPDynamic(BaseVariable, ABC):
    """Abstract base for all cub_hdp dynamic variables.

    Provides the shared extraction pipeline: optimised obs-level selection,
    caching, attribute building, time-filtering, and cleaning.  Subclasses
    implement :meth:`extract_from_db` to supply the raw data.

    Args:
        var_name: Variable name.
        dynamic: Whether the variable is dynamic (time-series).
        requires: Dependency variables loaded before extraction (defaults to
            no dependencies).
        tmin: Minimum time anchor.
        tmax: Maximum time anchor.
        py: Optional transformation function applied after raw extraction.
        py_ready_polars: Whether ``py`` accepts and returns polars DataFrames.
        cleaning: Value-range cleaning rules.
        allow_caching: Whether to cache raw data in the cohort's tmpdir.
        screened_obs_level: Default observation level used to scope DB queries
            when the time window is not bounded by hospital-stay bounds.
    """

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable | None = None,
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        allow_caching: bool = True,
        screened_obs_level: Literal["patient", "hospital_stay"] = "hospital_stay",
    ) -> None:
        super().__init__(
            var_name=var_name,
            dynamic=dynamic,
            # Fresh list per instance instead of a shared mutable default.
            requires=[] if requires is None else requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        self.allow_caching = allow_caching
        self.screened_obs_level: Literal[ObsLevel.HOSPITAL_STAY, ObsLevel.PATIENT] = ObsLevel(screened_obs_level)  # type: ignore[call-arg]

    @abstractmethod
    def extract_from_db(
        self,
        cohort: Cohort,
        id_column: Literal["case_id", "patient_id"],
    ) -> pl.DataFrame:
        """Return raw variable data; called by :meth:`extract` on a cache miss.

        The returned DataFrame must **not** yet have time-window columns or
        cleaning applied — those are handled by the shared pipeline.
        """
        ...

    def _optimise_screened_obs_level(
        self, cohort: Cohort
    ) -> Literal[ObsLevel.HOSPITAL_STAY, ObsLevel.PATIENT]:
        """Return the obs level used to scope the DB query and the cache key.

        Patient-level cohorts always use ``ObsLevel.PATIENT``. Otherwise the
        hospital-stay level is used when the time window is bounded by it,
        falling back to ``self.screened_obs_level``.
        """
        # Ignore screened_obs_level and use a fixed value for ObsLevel.PATIENT
        if cohort.obs_level == ObsLevel.PATIENT:
            return ObsLevel.PATIENT

        # Optimise search if bounded by smallest searchable level (hospital stay)
        bounded_by_hospital_stay_level = utils.bounded_by_obs_level(
            time_window=self.time_window,
            cohort=cohort,
            obs_level=ObsLevel.HOSPITAL_STAY,
        )
        return (
            ObsLevel.HOSPITAL_STAY
            if bounded_by_hospital_stay_level
            else self.screened_obs_level
        )

    def _clean_data(self) -> None:
        """Clip date/datetime columns to the ``datetime`` range; build attributes.

        Does nothing when ``self.data`` is ``None``. Attributes are built via
        ``utils.build_attributes`` only for dynamic variables.
        """
        if self.data is None:
            return

        self.data = self.data.with_columns(
            (pl.selectors.date() | pl.selectors.datetime()).clip(
                datetime.min, datetime.max
            )
        )

        if self.dynamic:
            self.data = utils.build_attributes(self.data)

    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """Extract the variable through the shared pipeline.

        Uses the tmpdir cache when it covers every id in ``cohort``. Otherwise
        it loads required variables, calls :meth:`extract_from_db`, runs the
        ``py`` function and writes the result to the cache. Dynamic variables
        are then joined with the time window, time-filtered and cleaned.

        Args:
            cohort: Cohort providing ids, time anchors and the tmpdir.

        Returns:
            The extracted data (also stored as ``self.data``).

        Raises:
            ValueError: If the pipeline produced no data.
        """
        narrowed_obs_level = self._optimise_screened_obs_level(cohort)
        cache_and_join_key = narrowed_obs_level.primary_key

        self.data = self._get_from_cache(cohort, cache_key=cache_and_join_key)

        if self.data is None:
            self._get_required_vars(cohort)
            self.data = self.extract_from_db(cohort, id_column=cache_and_join_key)  # type: ignore[arg-type]
            self._clean_data()
            self._call_var_function(cohort)  # Clean again afterwards

            if self.data is None:
                raise ValueError("Variable extraction resulted in no data.")

            self._clean_data()
            self._write_to_cache(cohort, cache_key=cache_and_join_key)

        if self.dynamic:
            self._add_time_window(cohort, join_key=cache_and_join_key)
            self._timefilter()
            self._apply_cleaning()

        return self.data

    def _cache_base_path(self, cohort: Cohort, cache_key: str):
        """Return the extension-less cache path for this variable and key.

        Reads and writes both go through this helper so they always resolve
        to the same files.
        """
        return cohort.tmpdir_manager.create_tmpdir_variable_path(
            var_name=f"{self.var_name}:{cache_key}", extension=None
        )

    def _get_from_cache(self, cohort: Cohort, cache_key: str) -> pl.DataFrame | None:
        """Load cached data if it covers all ``cache_key`` ids of ``cohort._obs``.

        Returns ``None`` when caching is disabled, a cache file is missing, or
        the cohort contains ids that were not cached. On a hit, ``self.data``
        is set and rows are filtered to the cohort's ids when the cache-key
        column is present.
        """
        base_path = self._cache_base_path(cohort, cache_key)
        parquet_path = f"{base_path}.parquet"
        cacheinfo_path = f"{base_path}.cacheinfo"

        # Before loading, ensure that cacheinfo exists and all ids from cohort.obs are present
        if not (
            self.allow_caching
            and os.path.exists(parquet_path)
            and os.path.exists(cacheinfo_path)
        ):
            return None

        # The cacheinfo file is written by _write_to_cache into the cohort's
        # tmpdir; unpickling is only acceptable because that dir is trusted.
        with open(cacheinfo_path, "rb") as f:
            cached_ids = pickle.load(f)

        load_ids = set(cohort._obs.select(cache_key).unique().to_series().to_list())

        # Return none if cohort.obs contains uncached load_ids
        if not load_ids.issubset(cached_ids):
            return None

        self.data = pl.read_parquet(parquet_path)
        logger.info(
            __(
                "Loaded {var_name} from tmpdir ({rows} rows)",
                var_name=self.var_name,
                rows=len(self.data),
            )
        )
        if cache_key in self.data.columns:
            self.data = self.data.filter(pl.col(cache_key).is_in(load_ids))
            logger.info(
                __("Found {rows} rows matching cohort.obs", rows=len(self.data))
            )

        return self.data

    def _write_to_cache(self, cohort: Cohort, cache_key: str) -> None:
        """Write ``self.data`` and the covered ids to the cohort's tmpdir.

        Does nothing when caching is disabled or there is no data.
        """
        if not self.allow_caching or self.data is None:
            return

        base_path = self._cache_base_path(cohort, cache_key)
        cacheinfo_path = f"{base_path}.cacheinfo"

        # Invalidate the previous entry first: if writing the parquet fails,
        # no stale cacheinfo is left pairing old ids with new/partial data.
        if os.path.exists(cacheinfo_path):
            os.remove(cacheinfo_path)

        self.data.write_parquet(f"{base_path}.parquet")
        saved_ids = cohort._obs.select(cache_key).unique().to_series().to_list()

        # Write cacheinfo last so its presence marks a complete cache entry
        with open(cacheinfo_path, "wb") as f:
            pickle.dump(saved_ids, f)

        logger.info(
            __(
                "Wrote {var_name} to tmpdir ({rows} rows)",
                var_name=self.var_name,
                rows=len(self.data),
            )
        )


class NativeExtractor:
    """Raw SQL extractor satisfying VariableProtocol.

    Fetches data by running a parameterised query against an HDP table.  No
    caching, time-filtering, or cleaning — those are handled by the wrapping
    :class:`NativeDynamic` pipeline.

    Args:
        var_name: Variable name.
        time_window: Time window (stored to satisfy the protocol).
        table: HDP table name.
        where: SQL WHERE clause appended to the generated query.
        columns: Column alias/cast mapping for the query.
        value_dtype: Optional dtype name passed to the query builder for the
            ``value`` column.
    """

    def __init__(
        self,
        var_name: str,
        *,
        time_window: TimeWindow | None = None,
        table: str | None = None,
        where: str | None = None,
        columns: dict[str, str | dict[str, str]] | None = None,
        value_dtype: str | None = None,
    ) -> None:
        self.var_name = var_name
        self.dynamic = True
        self.time_window = time_window or TimeWindow()
        self.table = table
        self.where = where
        self.columns = columns
        self.value_dtype: str | None = value_dtype

    def extract(
        self,
        cohort: Cohort,
        id_column: Literal["case_id", "patient_id"],
    ) -> pl.DataFrame:
        """Run the query for this variable, scoped to the cohort's ids.

        Args:
            cohort: Cohort whose ``id_column`` values restrict the query.
            id_column: Id column used to build the temporary filter table.

        Returns:
            The collected query result.

        Raises:
            ValueError: If no table was configured.
        """
        # Validate before creating the temporary table so a misconfigured
        # variable does not touch the database.
        if not self.table:
            raise ValueError(f"Table not specified for variable {self.var_name}")

        with helpers.temporary_table(cohort, id_column=id_column) as temp_tbl:
            logger.debug(
                __(
                    "Building query for {var_name}",
                    var_name=self.var_name,
                )
            )

            dbs = helpers.get_cub_hdp_database(cohort)
            original_id_column = helpers.get_table_info(self.table)[id_column]
            where_parts = [
                helpers.sql_tmp_filter(
                    original_id_column, id_column, dbs["workbench"], temp_tbl
                )
            ]
            if self.where:
                where_parts.append(self.where)

            query, alias_mapping, datetime_cols = helpers.build_native_extractor_query(
                database=dbs["prepared"],
                table=self.table,
                id_column=id_column,
                where="\n            AND ".join(where_parts),
                columns=self.columns,
                value_dtype=self.value_dtype,
            )
            conn = helpers.reinit_conn(cohort)

            result_lf: pl.LazyFrame = conn.lf_read_sql(
                query=query,
                datetime_cols=datetime_cols,
                keep_cols=alias_mapping.keys(),
            )
            # The query filters on the temporary table. In case lf_read_sql
            # defers execution until collect(), collect while the table still
            # exists. Collecting once also avoids re-running the query just to
            # log the row count.
            result_df = result_lf.collect()
            logger.info(
                __(
                    "Extracted {rows} rows",
                    rows=result_df.height,
                )
            )

        return result_df


class NativeDynamic(CubHDPDynamic):
    """SQL-backed dynamic variable with the full :class:`CubHDPDynamic` pipeline.

    Wraps a :class:`NativeExtractor` to provide caching, time-filtering,
    cleaning, and optional ``py`` post-processing.

    Args:
        table: HDP table name.
        where: SQL WHERE clause appended to the generated query.
        value_dtype: Optional polars dtype name for the ``value`` column
            (e.g. ``"Float64"``), forwarded to :class:`NativeExtractor`,
            which passes it to the query builder.
        columns: Column alias/cast mapping for the query.

    All other args are forwarded to :class:`CubHDPDynamic`.

    Note:
        Only use this class for defining custom variables. For CORR-specified
        variables, use ``cohort.add_variable()`` directly.
    """

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable | None = None,
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        allow_caching: bool = True,
        screened_obs_level: Literal["patient", "hospital_stay"] = "hospital_stay",
        table: str | None = None,
        where: str | None = None,
        value_dtype: str | None = None,
        columns: dict[str, str | dict[str, str]] | None = None,
    ) -> None:
        super().__init__(
            var_name=var_name,
            dynamic=dynamic,
            # Fresh list per instance instead of a shared mutable default.
            requires=[] if requires is None else requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
            allow_caching=allow_caching,
            screened_obs_level=screened_obs_level,
        )
        # The extractor shares this variable's time window; it only runs SQL.
        self.extractor = NativeExtractor(
            var_name=var_name,
            time_window=self.time_window,
            table=table,
            where=where,
            columns=columns,
            value_dtype=value_dtype,
        )

    def extract_from_db(
        self,
        cohort: Cohort,
        id_column: Literal["case_id", "patient_id"],
    ) -> pl.DataFrame:
        """Return the raw query result from the wrapped :class:`NativeExtractor`."""
        return self.extractor.extract(cohort, id_column)
class ComplexVariable(CubHDPDynamic):
    """Variable extracted entirely by a custom ``py()`` function.

    No database query is issued; the variable function receives
    ``required_vars`` and must return a ``pl.DataFrame``. Time-filtering is
    applied only when the function adds ``case_tmin``/``case_tmax`` columns
    to the result.

    Args:
        var_name: Name of the variable.
        dynamic: Whether the variable is dynamic.
        requires: Variables loaded and passed to the function.
        tmin: Minimum time anchor.
        tmax: Maximum time anchor.
        py: Variable function that produces the data.
        py_ready_polars: Whether ``py`` accepts and returns polars DataFrames.
        cleaning: Dictionary of cleaning rules.
        allow_caching: Whether to allow caching.
        screened_obs_level: Observation level used for scoping the cache key.
    """

    def __init__(
        self,
        var_name: str,
        dynamic: bool,
        requires: RequirementsIterable | None = None,
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        allow_caching: bool = True,
        screened_obs_level: Literal["patient", "hospital_stay"] = "hospital_stay",
    ) -> None:
        super().__init__(
            var_name=var_name,
            dynamic=dynamic,
            # Fresh list per instance instead of a shared mutable default.
            requires=[] if requires is None else requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
            allow_caching=allow_caching,
            screened_obs_level=screened_obs_level,
        )
        self._timefilter_always = False

    def extract_from_db(
        self,
        cohort: Cohort,
        id_column: Literal["case_id", "patient_id"],
    ) -> pl.DataFrame:
        """Always raises: ComplexVariable data is produced by ``py`` only."""
        raise NotImplementedError(
            f"ComplexVariable '{self.var_name}' does not extract from the database. "
            "Data must be produced inside the variable function."
        )

    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """Extract the variable via the cache or the variable function.

        Returns:
            The extracted data (also stored as ``self.data``).

        Raises:
            ValueError: If the variable function produced no data.
        """
        narrowed_obs_level = self._optimise_screened_obs_level(cohort)
        cache_and_join_key = narrowed_obs_level.primary_key

        # Per-call state: set below only when *this* call reloads cached data
        # whose time-window columns were dropped and must be recomputed. A
        # value left over from a previous call must not leak into this one.
        self._timefilter_always = False

        data = self._get_from_cache(cohort, cache_key=cache_and_join_key)
        if (
            data is not None
            and "case_tmin" in data.columns
            and "case_tmax" in data.columns
        ):
            # For complex variables that have precomputed case_tmin and case_tmax
            # The timefilter needs to be reapplied - we will set timefilter_always to True
            # For this, we need to drop the case_tmin, case_tmax non cache_key primary keys
            drop_cols = ["case_tmin", "case_tmax"] + (
                [cohort.primary_key]
                if cohort.primary_key != cache_and_join_key
                else []
            )
            data = data.drop(drop_cols, strict=False)
            if (
                "case_id" in data.columns
            ):  # TODO Change later after ensuring that cache_key is present in all variables
                self._timefilter_always = True
        else:
            # NOTE(review): cached data without case_tmin/case_tmax is discarded
            # and recomputed on every call — confirm this is intended.
            data = None

        self.data = data

        if self.data is None:
            self._get_required_vars(cohort)
            self._clean_data()
            self._call_var_function(cohort)  # Clean again afterwards

            if self.data is None:
                raise ValueError("Variable extraction resulted in no data.")

            self._clean_data()
            self._write_to_cache(cohort, cache_key=cache_and_join_key)

        if self.dynamic:
            # Only reloaded cached data (flag set above) gets its time window
            # re-added and re-filtered here.
            if self._timefilter_always:
                self._add_time_window(cohort, join_key=cache_and_join_key)
                self._timefilter()
            self._apply_cleaning()

        return self.data
[docs] class NativeStatic(BaseVariable): """Aggregated variables represent simple aggregations of dynamic variables. Args: var_name (str): Name of the variable. select (str): Select clause specifying aggregation function and columns. base_var (str): Name of the base variable (must be a native_dynamic variable). where (str, optional): Optional WHERE clause (in format for polars). tmin (str, optional): Minimum time for the extraction. tmax (str, optional): Maximum time for the extraction. The select argument supports several aggregation functions: - ``!first [columns]``: Returns the first row within this case >>> "!first value" # Single column >>> "!first value, recordtime" # Multiple columns - ``!last [columns]``: Returns the last row within this case >>> "!last value" >>> "!last value, recordtime" - ``!any``: Returns True if any value exists >>> "!any" >>> "!any value" - ``!sum [column]``: Calculates sum of values >>> "!sum value" - ``!closest(to_column, timedelta, plusminus) [columns]``: Selects value closest to specified column Args: to_column: Column to compare "recordtime" against timedelta: Time to add to "to_column" for comparison plusminus: Allowed time mismatch (can specify different before/after with space) >>> "!closest(hospital_admission) value, recordtime" # Closest to admission >>> "!closest(hospital_admission, 0, 2h 3h) value" # 2h before to 3h after >>> "!closest(first_intubation_dtime, 6h, 2h) value" # 6h after intubation ±2h - ``!mean [column]``: Calculates mean value >>> "!mean value" - ``!median [column]``: Calculates median value >>> "!median value" - ``!perc(quantile) [column]``: Calculates specified percentile >>> "!perc(75) value" # 75th percentile The where argument supports SQL-style boolean expressions. These are evaluated in the context of the base variable by column_selector(). Where also supports magic commands (starting with !) to filter the data. 
Supported commands are: - ``!isin(column, [values])``: Filters rows where the value in column is in values - ``!startswith(column, [values])``: Filters rows where the value in column starts with any of the values - ``!endswith(column, [values])``: Filters rows where the value in column ends with any of the values """ base_var: tuple[str, TimeWindow] | VariableProtocol | MultiSourceVariable def __init__( self, var_name: str, select: str, base_var: str | VariableProtocol | MultiSourceVariable, where: str | None = None, tmin: TimeAnchorColumn | None = None, tmax: TimeAnchorColumn | None = None, cleaning: CleaningDict | None = None, ) -> None: super().__init__( var_name, dynamic=False, tmin=tmin, tmax=tmax, cleaning=cleaning, ) logger.debug( __( "NativeStatic: SELECT {select} FROM {base_var} WHERE {where}", select=select, base_var=( base_var if isinstance(base_var, str) else base_var.__class__.__name__ ), where=where, ) ) self.select = select self.where = where if isinstance(base_var, str): _var = (base_var, self.time_window) else: _var = base_var _var.time_window = self.time_window self.base_var = _var self.agg_func, self.agg_params, self.select_cols = helpers.parse_select( self.select ) if len(self.select_cols) == 0: self.select_cols = ["value"]
[docs] def extract(self, cohort: Cohort) -> pl.DataFrame: """Extract the variable. You do not need to call this yourself, as it is called internally when you add the variable to a cohort. However, you may call it directly to obtain variable data independently of the cohort. You still need a cohort object for case ids and other metadata. Args: cohort: Cohort object. Returns: Extracted variable. After extraction, you may also access the data as ``Variable.data``. Examples: >>> var = NativeStatic( ... var_name="first_sodium_recordtime", ... select="!first recordtime", ... base_var="blood_sodium", ... tmin="hospital_admission" ... ) >>> var.extract(cohort) # With var.extract(), the data will not be added to the cohort. >>> var.data # You can access the data as a polars dataframe. """ include_sources = [self.guessed_source] if self.guessed_source else None _var = cohort.load_variable(self.base_var, include_sources=include_sources) self.base_data = _var.extract(cohort=cohort) self.data = self._extract_aggregation(cohort) self._apply_cleaning() self._call_var_function(cohort) return self.data
def _extract_aggregation(self, cohort: Cohort) -> pl.DataFrame: """Extract static variable using the base variable Args: cohort (Cohort): Cohort object. Returns: pl.DataFrame: Extracted variable. """ base_data = self.base_data.lazy() obs = cohort._obs.lazy() # Sort by recordtime base_data = base_data.sort("recordtime") # Apply where clause filtering if self.where: if self.where.startswith("!"): where_func, where_params, _ = helpers.parse_select(self.where) match where_func: case "isin": base_data = base_data.filter( column_selector(where_params[0]).is_in(where_params[1]) ) case "startswith": # Create OR expression for multiple prefixes if len(where_params[1]) == 1: base_data = base_data.filter( column_selector(where_params[0]).str.starts_with( where_params[1][0] ) ) else: starts_expr = column_selector( where_params[0] ).str.starts_with(where_params[1][0]) for prefix in where_params[1][1:]: starts_expr = starts_expr | column_selector( where_params[0] ).str.starts_with(prefix) base_data = base_data.filter(starts_expr) case "endswith": # Create OR expression for multiple suffixes if len(where_params[1]) == 1: base_data = base_data.filter( column_selector(where_params[0]).str.ends_with( where_params[1][0] ) ) else: ends_expr = column_selector(where_params[0]).str.ends_with( where_params[1][0] ) for suffix in where_params[1][1:]: ends_expr = ends_expr | column_selector( where_params[0] ).str.ends_with(suffix) base_data = base_data.filter(ends_expr) case _: raise ValueError( f"Unsupported where function: {where_func}. 
Supported functions are: isin, startswith, endswith" ) else: # Convert pandas-style where clause to polars expression polars_where = pl.sql_expr(self.where) base_data = base_data.filter(polars_where) # Parse time arguments obs = utils.add_time_window_expr( obs, time_window=self.time_window, aliases=("tmin", "tmax") ) # Join with observation times base_data = base_data.join( obs.select([cohort.primary_key, "tmin", "tmax"]), on=cohort.primary_key, how="left", ) logger.debug( __( "Base data shape (before filtering): {rows}", rows=utils.row_length(base_data), ) ) base_data_time = base_data.filter( pl.col("recordtime").is_between("tmin", "tmax") ) logger.debug( __( "Base data shape (after filtering tmin and tmax): {rows}", rows=utils.row_length(base_data), ) ) params = self.agg_params # Perform aggregation using Polars syntax match self.agg_func.lower(): case "first": res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).first() for col in self.select_cols ) case "last": res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).last() for col in self.select_cols ) case "mean": res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).mean() for col in self.select_cols ) case "median": res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).median() for col in self.select_cols ) case "min": res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).min() for col in self.select_cols ) case "max": res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).max() for col in self.select_cols ) case "std": res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).std() for col in self.select_cols ) case "perc": quantile = float(params[0]) / 100 res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).quantile(quantile) for col in self.select_cols ) case "sum": res = base_data_time.group_by(cohort.primary_key).agg( 
column_selector(col).sum() for col in self.select_cols ) case "count": res = base_data_time.group_by(cohort.primary_key).agg( pl.len().alias(self.select_cols[0]) ) case "any": res = base_data_time.group_by(cohort.primary_key).agg( column_selector(col).is_not_null().any() for col in self.select_cols ) # Right join with all observations to include cases with no data res = ( obs.select(cohort.primary_key) .unique() .join(res, on=cohort.primary_key, how="left") .with_columns( column_selector(col).fill_null(False) for col in self.select_cols ) ) case "sum_interval": # Use unfiltered base_data instead of base_data_time here if "recordtime_end" not in base_data.columns: raise ValueError( "sum_interval requires 'recordtime_end' column in the base variable data" ) interval_data = ( base_data.with_columns( [ # Ensure recordtime_end is not null, use recordtime as fallback pl.when(pl.col("recordtime_end").is_null()) .then(pl.col("recordtime")) .otherwise(pl.col("recordtime_end")) .alias("recordtime_end_clean"), ] ) .with_columns( [ # Ensure interval is valid (end >= start), swap if necessary pl.when( pl.col("recordtime_end_clean") < pl.col("recordtime") ) .then( pl.col("recordtime") ) # Use recordtime as both start and end if invalid .otherwise(pl.col("recordtime_end_clean")) .alias("recordtime_end_clean"), ] ) .with_columns( [ # Calculate overlap start (max of interval start and tmin) pl.max_horizontal(["recordtime", "tmin"]).alias( "overlap_start" ), # Calculate overlap end (min of interval end and tmax) pl.min_horizontal(["recordtime_end_clean", "tmax"]).alias( "overlap_end" ), ] ) ) # Filter to only intervals that have some overlap interval_data = interval_data.filter( pl.col("overlap_start") <= pl.col("overlap_end") ) # Calculate overlap duration and total interval duration interval_data = interval_data.with_columns( [ # Overlap duration in seconds (pl.col("overlap_end") - pl.col("overlap_start")) .dt.total_seconds() .alias("overlap_seconds"), # Total interval duration in 
seconds (pl.col("recordtime_end_clean") - pl.col("recordtime")) .dt.total_seconds() .alias("interval_seconds"), ] ) # Calculate proportional values for each selected column proportional_exprs = [] for col in self.select_cols: # Handle division by zero: if interval is instantaneous (0 duration), use full value proportional_expr = ( pl.when(pl.col("interval_seconds") <= 0) .then(pl.col(col)) # Full value for instantaneous intervals .otherwise( pl.col(col) * pl.col("overlap_seconds") / pl.col("interval_seconds") ) .alias(f"{col}_proportional") ) proportional_exprs.append(proportional_expr) interval_data = interval_data.with_columns(proportional_exprs) # Sum the proportional values by primary key sum_exprs = [ pl.col(f"{col}_proportional").sum().alias(col) for col in self.select_cols ] res = interval_data.group_by(cohort.primary_key).agg(sum_exprs) # Ensure all cases are included, even those with no overlapping intervals # Join with all observations to include cases with zero sum res = ( obs.select(cohort.primary_key) .unique() .join(res, on=cohort.primary_key, how="left") .with_columns( [pl.col(col).fill_null(0.0) for col in self.select_cols] ) ) logger.debug( __( "sum_interval processed {rows} interval records", rows=lambda: interval_data.select(pl.len()).collect().item(), ) ) case "closest": to_col = params[0] tdelta = params[1] if len(params) > 1 else None plusminus = params[2] if len(params) > 2 else "52w" if " " in plusminus: pm_before, pm_after = plusminus.split(" ") else: pm_before = pm_after = plusminus # Convert time deltas to total seconds for tolerance pm_before_secs = pd.to_timedelta(pm_before).total_seconds() pm_after_secs = pd.to_timedelta(pm_after).total_seconds() tdelta_secs = pd.to_timedelta(tdelta).total_seconds() if tdelta else 0.0 # Create target times dataframe target_times = obs.select( cohort.primary_key, column_selector(to_col) .add(pl.duration(seconds=int(tdelta_secs))) .alias("target_time"), ) # Use merge_asof for efficient time-based 
matching # For asymmetric tolerances, we need to handle this differently if pm_before_secs == pm_after_secs: # Symmetric tolerance - use merge_asof directly res = target_times.join_asof( base_data_time.select( [cohort.primary_key, "recordtime"] + self.select_cols ), left_on="target_time", right_on="recordtime", by=cohort.primary_key, tolerance=str(int(pm_before_secs)) + "s", strategy="nearest", ).drop("target_time") else: # Asymmetric tolerance - filter then use merge_asof # Create time bounds target_times = target_times.with_columns( pl.col("target_time") .sub(pl.duration(seconds=int(pm_before_secs))) .alias("time_min"), pl.col("target_time") .add(pl.duration(seconds=int(pm_after_secs))) .alias("time_max"), ) # Filter data within the asymmetric window using a join filtered_data = base_data_time.join( target_times.select( cohort.primary_key, "time_min", "time_max", "target_time" ), on=cohort.primary_key, how="inner", ).filter( pl.col("recordtime").ge(pl.col("time_min")) & pl.col("recordtime").le(pl.col("time_max")) ) # Now use merge_asof to find the closest match within the filtered data res = target_times.join_asof( filtered_data.select( [cohort.primary_key, "recordtime"] + self.select_cols ), left_on="target_time", right_on="recordtime", by=cohort.primary_key, strategy="nearest", ).drop("target_time", "time_min", "time_max") case _: raise ValueError(f"Unsupported aggregation function: {self.agg_func}") logger.info( __( "Extracted {var_name}.\nColumns: {columns}", var_name=self.var_name, columns=lambda: ", ".join(utils.columns(res)), ) ) res = res.select([cohort.primary_key] + self.select_cols) # Apply column naming # TODO: Remove if this is already handled by Cohort._save_variable if len(self.select_cols) > 1: rename_map = {col: f"{self.var_name}_{col}" for col in self.select_cols} res = res.rename(rename_map) else: res = res.rename({self.select_cols[0]: self.var_name}) return res.collect()
class DerivedStatic(BaseVariable):
    """DerivedStatic: derivations on existing columns in ``cohort.obs``.

    Args:
        var_name: Name of the variable.
        requires: List of required variables.
        expression: SQL-style expression computing the variable.
        tmin: Minimum time for the extraction.
        tmax: Maximum time for the extraction.
        py: Optional custom variable function.
        py_ready_polars: Whether ``py`` accepts and returns polars DataFrames.
        dynamic: Must be False; present for configuration compatibility.
        cleaning: Value-range cleaning rules.

    Note that DerivedStatic variables are executed on the cohort.obs dataframe
    and must reference existing columns in cohort.obs (or static required
    variables, which are joined in first).

    For DerivedStatic variables, you may either provide an SQL-like expression
    (parsed with ``polars.sql_expr``) or a custom function in variables.py.
    Use expressions where possible, but custom functions if you require more
    complex logic.

    Examples:
        >>> DerivedStatic(
        ...     var_name="inhospital_death",
        ...     requires=["hospital_discharge", "death_timestamp"],
        ...     expression="hospital_discharge <= death_timestamp"
        ... )

        >>> DerivedStatic(
        ...     var_name="any_va_ecmo_icu",
        ...     requires=["ecmo_va_icu_ops", "ecmo_va_icu"],
        ...     expression="ecmo_va_icu_ops OR ecmo_va_icu"
        ... )
    """

    def __init__(
        self,
        var_name: str,
        requires: RequirementsIterable | None = None,
        expression: str | None = None,
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        dynamic: bool = False,
        cleaning: CleaningDict | None = None,
    ) -> None:
        assert (
            dynamic is False
        ), "DerivedStatic cannot be dynamic, please verify the configuration."
        super().__init__(
            var_name=var_name,
            dynamic=False,
            # Fresh list per instance instead of a shared mutable default.
            requires=[] if requires is None else requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
        )
        self.expression = expression
        self.computed_expression = False

    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """Compute the variable from ``cohort.obs`` and required variables.

        Returns:
            The primary key, ``var_name`` and any ``{var_name}_*`` columns.
        """
        self._get_required_vars(cohort)
        self.data = self._compute_expression(cohort)
        called_var_function = self._call_var_function(cohort)
        assert (
            self.computed_expression or called_var_function
        ), "No expression or variable function found."
        assert (
            self.var_name in self.data.columns
        ), f"DerivedStatic variable {self.var_name} not found in extracted data. ({self.data.columns})"

        # Get all columns that are either the primary key, the variable name itself, or start with the variable name
        cols_to_keep = [
            col
            for col in self.data.columns
            if col == cohort.primary_key
            or col == self.var_name
            or (isinstance(col, str) and col.startswith(f"{self.var_name}_"))
        ]
        self.data = self.data.select(cols_to_keep)
        self._apply_cleaning()
        return self.data

    def _compute_expression(self, cohort: Cohort):
        """Compute the expression using the required variables.

        Static required variables with data are left-joined onto a copy of
        ``cohort._obs`` before the expression is evaluated.

        Args:
            cohort (Cohort): Cohort object.

        Returns:
            pl.DataFrame: Extracted variable. (Returns the joined obs dataframe
            unchanged if no expression is provided.)
        """
        obs = cohort._obs.clone()
        for var in self.required_vars.values():
            if not var.dynamic and var.data is not None:
                obs = obs.join(
                    var.data.select([cohort.primary_key, var.var_name]),
                    on=cohort.primary_key,
                    how="left",
                )

        if not self.expression:
            return obs

        expr = pl.sql_expr(self.expression)
        obs = obs.with_columns(expr.alias(self.var_name))
        self.computed_expression = True
        return obs
class DerivedDynamic(BaseVariable):
    """Dynamic variable produced by a custom function from other variables.

    Extraction loads the required variables, runs ``py`` to populate
    ``self.data`` and finally applies the configured cleaning rules.

    Args:
        var_name: Name of the variable.
        requires: Variables that must be loaded before ``py`` runs.
        cleaning: Cleaning parameters ({column_name: {low: int, high: int}}).
        tmin: Minimum time for the extraction.
        tmax: Maximum time for the extraction.
        py: Custom function computing the variable.
        py_ready_polars: Whether ``py`` works on polars DataFrames. Defaults to
            False (input and output are pandas dataframes).
        dynamic: Must be True; accepted so configurations can pass it explicitly.

    Examples:
        >>> def var_func(var, cohort):
        ...     # Simplified example: only works if blood_pao2_arterial and vent_fio2
        ...     # have the same length (which is probably not the case).
        ...     return var.with_columns(
        ...         (var.required_vars["blood_pao2_arterial"] / var.required_vars["vent_fio2"]).alias("pf_ratio")
        ...     )
        >>> DerivedDynamic(
        ...     var_name="pf_ratio",
        ...     requires=["blood_pao2_arterial", "vent_fio2"],
        ...     py=var_func,
        ...     py_ready_polars=True)
    """

    def __init__(
        self,
        var_name,
        requires: RequirementsIterable,
        cleaning: CleaningDict | None = None,
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        dynamic: bool = True,
    ) -> None:
        assert dynamic is True, "DerivedDynamic must be dynamic, please verify the configuration."
        super().__init__(
            var_name,
            dynamic=True,
            requires=requires,
            cleaning=cleaning,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
        )

    def extract(self, cohort: Cohort) -> pl.DataFrame:
        """Load dependencies, run the variable function and clean the result.

        Args:
            cohort (Cohort): Cohort object.

        Returns:
            pl.DataFrame: Extracted variable.

        Raises:
            ValueError: If the variable function left ``self.data`` unset.
        """
        self._get_required_vars(cohort)
        self._call_var_function(cohort)

        produced_nothing = self.data is None
        if produced_nothing:
            raise ValueError("Variable extraction resulted in no data.")

        self._apply_cleaning()
        return self.data
class Co6HierarchyExtractor:
    """Raw CoPRA 6 hierarchy extractor satisfying VariableProtocol.

    Fetches data by joining child nodes under a parent in the CoPRA 6 hierarchy
    table. No caching, time-filtering, or cleaning — those are handled by the
    wrapping :class:`Co6HierarchyDynamic` pipeline.

    Args:
        var_name: Variable name.
        parent_name: CoPRA 6 parent node name (e.g. ``"Score_SOFA"``).
        suffixes: Maps ``"_Suffix"`` → output column name. Must include a
            suffix that resolves to ``"recordtime"``.
        table: Source table. Defaults to ``"it_copra6_hierarchy_v2"``.
        not_strict: Suffixes to LEFT JOIN instead of INNER JOIN.
        time_window: Time window used to scope the SQL query.
    """

    def __init__(
        self,
        var_name: str,
        *,
        parent_name: str,
        suffixes: dict[str, str],
        not_strict: list[str] | None = None,
        table: str = "it_copra6_hierarchy_v2",
        time_window: TimeWindow | None = None,
    ) -> None:
        self.var_name = var_name
        self.dynamic = True
        self.time_window = time_window or TimeWindow()
        self.parent_name = parent_name
        self.suffixes = suffixes
        self.not_strict: list[str] = not_strict or []
        self.table = table

    def extract(
        self,
        cohort: Cohort,
        id_column: Literal["case_id", "patient_id"] = "case_id",
    ) -> pl.DataFrame:
        """Query the hierarchy table for all members of the cohort.

        Args:
            cohort: Cohort whose identifiers scope the query (via a temporary table).
            id_column: Identifier used to filter the source table.

        Returns:
            pl.DataFrame: Collected query result; ``recordtime`` is converted
            from UTC to naive Europe/Berlin local time.
        """
        with helpers.temporary_table(cohort, id_column=id_column) as temp_tbl:
            logger.debug(
                __(
                    "Building query for {var_name}",
                    var_name=self.var_name,
                )
            )

            # Build query (TODO: Table alias for original_id_column not robust)
            dbs = helpers.get_cub_hdp_database(cohort)
            original_id_column = "p." + helpers.get_table_info(self.table)[id_column]
            where = helpers.sql_tmp_filter(
                original_id_column, id_column, dbs["workbench"], temp_tbl
            )
            (
                query,
                table_info,
                datetime_cols,
            ) = helpers.build_co6hierarchy_extractor_query(
                database=dbs["prepared"],
                table=self.table,
                id_column=id_column,
                where=where,
                parent_name=self.parent_name,
                suffixes=self.suffixes,
                not_strict=self.not_strict,
            )

            # Execute query
            conn = helpers.reinit_conn(cohort)
            result_df: pl.LazyFrame = conn.lf_read_sql(
                query=query,
                datetime_cols=datetime_cols,
                keep_cols=table_info.values(),
            )

            # Note: Timestamps that are stored in c_value are UTC, while everything else (c_var_timestamp) is local time.
            # We need to convert the c_value timestamps to local time.
            result_df = result_df.with_columns(
                pl.col("recordtime")
                .dt.replace_time_zone("UTC")
                .dt.convert_time_zone("Europe/Berlin")
                .dt.cast_time_unit("us")
                .dt.replace_time_zone(None)
            )

            # Collect exactly once (while the temporary table still exists) and
            # log the row count of the materialized frame, instead of evaluating
            # the lazy plan a second time just for the log message.
            collected = result_df.collect()
            logger.info(
                __(
                    "Extracted {rows} rows",
                    rows=collected.height,
                )
            )
            return collected


class Co6HierarchyDynamic(CubHDPDynamic):
    """CoPRA 6 hierarchy variable with the full :class:`CubHDPDynamic` pipeline.

    Wraps a :class:`Co6HierarchyExtractor` to provide caching, time-filtering,
    cleaning, and optional ``py`` post-processing.

    Args:
        parent_name: CoPRA 6 parent node name.
        suffixes: Maps ``"_Suffix"`` → output column name.
        table: Source table. Defaults to ``"it_copra6_hierarchy_v2"``.
        not_strict: Suffixes to LEFT JOIN instead of INNER JOIN.
        value_dtype: Optional polars dtype name applied to the ``value`` column
            after extraction (e.g. ``"Float64"``, ``"Int64"``).

    All other args are forwarded to :class:`CubHDPDynamic`.
    """

    def __init__(
        self,
        var_name: str,
        dynamic: bool = True,
        requires: RequirementsIterable = [],
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        allow_caching: bool = True,
        screened_obs_level: Literal["patient", "hospital_stay"] = "hospital_stay",
        *,
        parent_name: str,
        suffixes: dict[str, str],
        table: str = "it_copra6_hierarchy_v2",
        not_strict: list[str] | None = None,
        value_dtype: str | None = None,
    ) -> None:
        super().__init__(
            var_name=var_name,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
            allow_caching=allow_caching,
            screened_obs_level=screened_obs_level,
        )
        self.value_dtype = value_dtype
        self.extractor = Co6HierarchyExtractor(
            var_name=var_name,
            parent_name=parent_name,
            suffixes=suffixes,
            table=table,
            not_strict=not_strict,
            time_window=self.time_window,
        )

    def extract_from_db(
        self, cohort: Cohort, id_column: Literal["case_id", "patient_id"]
    ) -> pl.DataFrame:
        """Fetch raw rows, attach the cohort primary key and cast ``value``.

        Args:
            cohort: Cohort object.
            id_column: Identifier used for the database query and the join.

        Returns:
            pl.DataFrame: Raw rows inner-joined with the cohort's key columns.
        """
        raw = self.extractor.extract(cohort, id_column)
        # TODO: Find better way for the workaround to have the primary_key in the df before the variable function
        # When the primary key *is* the id column, selecting it twice would raise a
        # polars duplicate-column error, so de-duplicate while keeping order.
        key_cols = list(dict.fromkeys((cohort.primary_key, id_column)))
        raw = raw.join(cohort._obs.select(key_cols), on=id_column)
        if self.value_dtype is not None:
            dtype = getattr(pl, self.value_dtype)
            raw = raw.with_columns(pl.col("value").cast(dtype, strict=False))
        return raw


class MedicationExtractor:
    """Raw medication data extractor satisfying VariableProtocol.

    Fetches all drug administration rows for the requested canonical drugs or
    ATC codes. No caching, no time-filtering, no cleaning — those are handled by
    the wrapping :class:`MedicationDynamic` pipeline.

    Args:
        var_name: Variable name (used as the canonical drug name when neither
            ``drug`` nor ``atc`` is given).
        drug: One or more canonical drug names, optionally qualified with an
            ingredient suffix (``"canonical:ingredient"``).
        atc: One or more ATC prefix strings used to look up canonical names.
    """

    def __init__(
        self,
        var_name: str,
        *,
        drug: str | list[str] | None = None,
        atc: str | list[str] | None = None,
    ) -> None:
        self.var_name = var_name
        self.dynamic = True
        self.time_window = TimeWindow()
        self.drugs: list[str] = [drug] if isinstance(drug, str) else drug or []
        self.atc: list[str] = [atc] if isinstance(atc, str) else atc or []

    def extract(
        self,
        cohort: Cohort,
        id_column: Literal["case_id", "patient_id"] = "case_id",
    ) -> pl.DataFrame:
        """Fetch and concatenate the administrations of every requested drug.

        Args:
            cohort: Cohort object.
            id_column: Identifier used to scope the query.

        Returns:
            pl.DataFrame: Concatenated output of :func:`get_drug` per target.

        Raises:
            ValueError: If only ATC prefixes were given and none of them matches
                a canonical drug in the mapping.
        """
        if self.drugs or self.atc:
            canonical_targets = (
                self._canonical_targets_from_drugs()
                + self._canonical_targets_from_atc()
            )
            if not canonical_targets:
                # Do not silently fall back to var_name: that would extract an
                # unrelated drug (or fail with a misleading message).
                raise ValueError(
                    f"No canonical drug in the mapping matches ATC prefixes {self.atc}."
                )
        else:
            canonical_targets = [(self.var_name, None)]

        # A target listed via both `drug` and `atc` is fetched only once, so its
        # administrations are not duplicated in the concatenated result.
        canonical_targets = list(dict.fromkeys(canonical_targets))

        df_mapping = pl.scan_parquet(DRUG_MAPPING_PATH)
        drug_dfs = [
            get_drug(
                cohort=cohort,
                df_mapping=df_mapping,
                canonical=canonical,
                target_ingredient=target_ingredient,
                id_column=id_column,
            )
            for canonical, target_ingredient in canonical_targets
        ]
        return pl.concat(drug_dfs).collect()

    def _canonical_targets_from_drugs(self) -> list[tuple[str, str | None]]:
        """Split ``"canonical[:ingredient]"`` entries into (canonical, ingredient)."""
        return [
            (part[0], part[2] or None)
            for part in [d.partition(":") for d in self.drugs]
        ]

    def _canonical_targets_from_atc(self) -> list[tuple[str, str | None]]:
        """Map ATC prefixes to all matching canonical drug names in the mapping."""
        if not self.atc:
            return []
        canonicals = cast(
            list[str],
            pl.read_parquet(DRUG_MAPPING_PATH)
            .filter(
                pl.any_horizontal(
                    pl.col("drug_atc").str.starts_with(a) for a in self.atc
                )
            )["drug_canonical"]
            .unique()
            .to_list(),
        )
        return [(canonical, None) for canonical in canonicals]


class MedicationDynamic(CubHDPDynamic):
    """Medication variable with the full :class:`CubHDPDynamic` pipeline.

    Wraps a :class:`MedicationExtractor` so that caching, time-filtering,
    cleaning, and optional ``py`` transformations all apply automatically.

    Args:
        var_name: Variable name.
        drug: One or more canonical drug names (``"canonical"`` or
            ``"canonical:ingredient"``).
        atc: One or more ATC prefix strings.
        requires: Dependencies loaded before the optional ``py`` function.
        py: Optional transformation applied after raw extraction.
        py_ready_polars: Whether ``py`` accepts and returns polars DataFrames.
        cleaning: Value-range cleaning rules.
        allow_caching: Whether to cache raw data in the cohort's tmpdir.
    """

    def __init__(
        self,
        var_name: str,
        dynamic: bool = True,
        requires: RequirementsIterable = [],
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        allow_caching: bool = True,
        *,
        drug: str | list[str] | None = None,
        atc: str | list[str] | None = None,
    ) -> None:
        super().__init__(
            var_name=var_name,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
            allow_caching=allow_caching,
        )
        self.extractor = MedicationExtractor(var_name=var_name, drug=drug, atc=atc)

    def extract_from_db(
        self, cohort: Cohort, id_column: Literal["case_id", "patient_id"]
    ) -> pl.DataFrame:
        """Delegate raw extraction to the wrapped :class:`MedicationExtractor`."""
        return self.extractor.extract(cohort, id_column)


class LaboratoryExtractor:
    """Raw laboratory data extractor satisfying VariableProtocol.

    Fetches all measurement rows for the requested lab concepts. No caching,
    no time-filtering, no cleaning — those are handled by the wrapping
    :class:`LaboratoryDynamic` pipeline.

    Args:
        var_name: Variable name (used as the lab concept when ``lab`` is not given).
        lab: One or more lab concept names to fetch.
    """

    def __init__(
        self,
        var_name: str,
        *,
        lab: str | list[str] | None = None,
    ) -> None:
        self.var_name = var_name
        self.dynamic = True
        self.time_window = TimeWindow()
        self.labs: list[str] = [lab] if isinstance(lab, str) else lab or []

    def extract(
        self,
        cohort: Cohort,
        id_column: Literal["case_id", "patient_id"] = "case_id",
    ) -> pl.DataFrame:
        """Fetch and concatenate the measurements of every requested lab concept."""
        lab_concepts = self.labs or [self.var_name]
        df_mapping = pl.scan_parquet(LAB_MAPPING_PATH)
        lab_dfs = [
            get_lab(
                cohort=cohort,
                df_mapping=df_mapping,
                lab_concept=concept,
                id_column=id_column,
            )
            for concept in lab_concepts
        ]
        return pl.concat(lab_dfs).collect()


class LaboratoryDynamic(CubHDPDynamic):
    """Laboratory variable with the full :class:`CubHDPDynamic` pipeline.

    Wraps a :class:`LaboratoryExtractor` so that caching, time-filtering,
    cleaning, and optional ``py`` transformations all apply automatically.

    Args:
        var_name: Variable name.
        lab: One or more lab concept names to fetch.
        requires: Dependencies loaded before the optional ``py`` function.
        py: Optional transformation applied after raw extraction.
        py_ready_polars: Whether ``py`` accepts and returns polars DataFrames.
        cleaning: Value-range cleaning rules.
        allow_caching: Whether to cache raw data in the cohort's tmpdir.
    """

    def __init__(
        self,
        var_name: str,
        dynamic: bool = True,
        requires: RequirementsIterable = [],
        tmin: TimeAnchorColumn | None = None,
        tmax: TimeAnchorColumn | None = None,
        py: VariableCallable | None = None,
        py_ready_polars: bool = False,
        cleaning: CleaningDict | None = None,
        allow_caching: bool = True,
        *,
        lab: str | list[str] | None = None,
    ) -> None:
        super().__init__(
            var_name=var_name,
            dynamic=dynamic,
            requires=requires,
            tmin=tmin,
            tmax=tmax,
            py=py,
            py_ready_polars=py_ready_polars,
            cleaning=cleaning,
            allow_caching=allow_caching,
        )
        self.extractor = LaboratoryExtractor(var_name=var_name, lab=lab)

    def extract_from_db(
        self, cohort: Cohort, id_column: Literal["case_id", "patient_id"]
    ) -> pl.DataFrame:
        """Delegate raw extraction to the wrapped :class:`LaboratoryExtractor`."""
        return self.extractor.extract(cohort, id_column)


def _safe_f64(col: str) -> pl.Expr:
    """Coerce Impala string-numbers ('15.000', '0,5', '0E-12') to Float64.

    Unparseable values become null (non-strict cast).
    """
    return (
        pl.col(col)
        .cast(pl.Utf8, strict=False)
        .str.strip_chars()
        .str.replace(",", ".", literal=True)
        .str.replace("0E-12", "0", literal=True)
        .cast(pl.Float64, strict=False)
    )


DRUG_APPS_SQL = """\
SELECT
    case_id,
    source_type,
    application_id,
    mixture_id,
    config_drug_id,
    drug_name,
    active_agent_id,
    ingredient_id,
    ingredient_name,
    application_start AS recordtime,
    application_end AS recordtime_end,
    app_quantity,
    app_volume,
    app_rate_volume,
    time_unit,
    is_weight_related,
    is_on_demand,
    mixture_name,
    mixture_volume,
    dose_clinical,
    dose_clinical_unit,
    concentration,
    concentration_unit,
    rate,
    rate_unit,
    application_location_name AS application_route,
    substance_group
FROM {db_apps}.{tbl_apps}
WHERE config_drug_id IN ({cdids})
    AND {tmp_filter}
""".rstrip()


def get_filtered_drugs(
    cohort: Cohort,
    cdids: list[str],
    id_column: Literal["case_id", "patient_id"] = "case_id",
) -> pl.LazyFrame:
    """Read drug applications for the given config_drug_ids, scoped to the cohort.

    Args:
        cohort: Cohort object.
        cdids: config_drug_id values to fetch.
        id_column: ``"case_id"`` filters directly; ``"patient_id"`` filters via
            the patient's cases and adds a ``patient_id`` column.

    Returns:
        pl.LazyFrame: Application rows. ``case_id`` is always present (it is
        needed for per-application de-duplication in :func:`get_drug`).
    """
    conn = helpers.reinit_conn(cohort)
    dbs = helpers.get_cub_hdp_database(cohort)
    # NOTE(review): the returned LazyFrame is collected by the caller after the
    # temporary-table context has exited; this relies on lf_read_sql not
    # deferring the database query — confirm.
    with helpers.temporary_table(cohort, id_column=id_column) as temp_tbl:
        if id_column == "patient_id":
            # drug_apps has no patient_id — filter via subquery through it_ishmed_fall
            tmp_filter = (
                f"case_id IN (SELECT c_falnr FROM {dbs['prepared']}.it_ishmed_fall "
                f"WHERE {helpers.sql_tmp_filter('c_patnr', 'patient_id', dbs['workbench'], temp_tbl)})"
            )
        else:
            tmp_filter = helpers.sql_tmp_filter(
                "case_id", "case_id", dbs["workbench"], temp_tbl
            )
        apps = conn.lf_read_sql(
            DRUG_APPS_SQL.format(
                db_apps=dbs["workbench"],
                tbl_apps="drug_apps",
                tmp_filter=tmp_filter,
                cdids=helpers.to_sql_IN_list(cdids),
            ),
            datetime_cols=["recordtime", "recordtime_end"],
        )
        if id_column == "patient_id":
            # Attach patient_id. case_id is kept because get_drug de-duplicates on
            # it; unique() prevents row fan-out if obs has several rows per case.
            fall = cohort._obs.lazy().select("patient_id", "case_id").unique()
            apps = apps.join(fall, on="case_id")
    return apps


def get_drug(
    cohort: Cohort,
    df_mapping: pl.LazyFrame,
    canonical: str,
    *,
    target_ingredient: str | None = None,
    id_column: Literal["case_id", "patient_id"] = "case_id",
) -> pl.LazyFrame:
    """
    Return all administrations of a canonical drug as a clean, unit-unified df.

    Output columns:
        <id_column>, recordtime, recordtime_end, value, value_unit, attributes
        (struct of canonical_drug_name, atc_code, application_route, rate,
        rate_unit, is_weight_related, source_type, application_id)

    Behaviour:
        - Fluids (target unit == "ml"): one row per (case, application);
          value = app_volume. Ingredient fan-out collapsed.
        - Non-fluids: value = dose_clinical * ingredient_unit_conversion_factor,
          labelled with the canonical drug_unit. Combo drugs: rows of the
          combined pseudo-ingredient (ingredient_canonical == canonical) are
          preferred when present; otherwise one row per component ingredient
          is returned (components are not summed here).

    Raises:
        ValueError: If ``canonical`` is not in the mapping, or its mapping rows
            have more than one most-frequent ``drug_unit``.
    """
    # Prepare mapping
    df_mapping = df_mapping.with_columns(
        pl.col("ingredient_id").cast(pl.Int64, strict=False),
        pl.col("config_drug_id").cast(pl.Int64, strict=False),
    )
    map_sub = df_mapping.filter(drug_canonical=canonical)
    if cast(int, map_sub.select(pl.len()).collect().item()) == 0:
        raise ValueError(f"Canonical '{canonical}' not found in mapping.")

    cdids = cast(
        list[str],
        map_sub.select(pl.col("config_drug_id").unique())
        .collect()["config_drug_id"]
        .to_list(),
    )

    # Target unit = most frequent non-null drug_unit. An empty result yields None
    # (instead of crashing in .item()); a tie is reported explicitly.
    unit_modes = (
        map_sub.select("drug_unit").drop_nulls().collect()["drug_unit"].mode()
    )
    if unit_modes.len() > 1:
        raise ValueError(
            f"Canonical '{canonical}' has no unique most frequent drug_unit: "
            f"{sorted(unit_modes.to_list())}."
        )
    target_unit = cast(
        str | None, unit_modes.item() if unit_modes.len() == 1 else None
    )
    is_fluid = target_unit == "ml"

    map_join = (
        map_sub.filter(pl.col("ingredient_id").is_not_null())
        .select(
            "config_drug_id",
            "ingredient_id",
            "ingredient_canonical",
            "ingredient_unit",
            "ingredient_unit_conversion_factor",
            "drug_canonical",
            "drug_unit",
            "drug_atc",
            "route",
        )
        .unique(subset=["config_drug_id", "ingredient_id"], keep="first")
    )

    df_apps = get_filtered_drugs(cohort, cdids, id_column)

    joined = df_apps.join(
        other=map_join, on=["config_drug_id", "ingredient_id"], how="left"
    ).filter(drug_canonical=canonical)

    # Fluid path: collapse ingredient fan-out, return volume
    if is_fluid:
        joined = joined.unique(
            subset=["case_id", "application_id"], keep="first"
        ).with_columns(
            _safe_f64("app_volume").alias("value"),
            pl.lit("ml").alias("value_unit"),
        )

    # Non-fluid path: dose × conversion factor in canonical unit
    else:
        joined = joined.filter(pl.col("ingredient_canonical").is_not_null())
        if target_ingredient is not None:
            joined = joined.filter(ingredient_canonical=target_ingredient)
        else:
            # Combo handling: prefer combined pseudo-ingredient when it exists
            # NOTE(review): .all() is evaluated over the whole frame (all cases and
            # applications), not per application — confirm this is intended.
            joined = (
                joined.with_columns(
                    # Will create a column which is either all True or False
                    pl.col("ingredient_canonical")
                    .ne(canonical)
                    .all()
                    .alias("_no_combos")
                )
                .filter(
                    # Take all cols if there is no combo row or combo rows if there are some
                    pl.col("_no_combos")
                    | pl.col("ingredient_canonical").eq(canonical)
                )
                .drop("_no_combos")
            )

        joined = joined.with_columns(
            _safe_f64("dose_clinical")
            .mul(
                pl.col("ingredient_unit_conversion_factor").cast(
                    pl.Float64, strict=False
                )
            )
            .alias("value"),
            # Typed literal so a missing unit still concatenates with other drugs.
            pl.lit(target_unit, dtype=pl.Utf8).alias("value_unit"),
        )
        joined = joined.unique(
            subset=["case_id", "application_id", "ingredient_canonical"],
            keep="first",
        )

    return joined.select(
        id_column,
        "recordtime",
        "recordtime_end",
        "value",
        "value_unit",
        pl.struct(
            pl.col("drug_canonical").alias("canonical_drug_name"),
            pl.col("drug_atc").alias("atc_code"),
            "application_route",
            "rate",
            "rate_unit",
            "is_weight_related",
            "source_type",
            "application_id",
        ).alias("attributes"),
    )


# TODO: Use NativeExtractor
# The {where} clause is an OR-chain of conditions; it MUST stay parenthesised,
# otherwise AND binds tighter and only its last condition is cohort-filtered.
LAB_SQL = """\
SELECT
    c_patnr AS patient_id,
    c_falnr AS case_id,
    c_dokument_timestamp AS recordtime,
    c_wert,
    c_wert_einheit,
    c_leistungnr,
    c_katalog_leistungtext,
    c_wert_timestamp AS storetime
FROM {db_labs}.{tbl_labs}
WHERE
    (
{where}
    )
    AND {tmp_filter}
""".strip()


def get_filtered_labs(
    cohort: Cohort,
    where: str,
    id_column: Literal["case_id", "patient_id"] = "case_id",
) -> pl.LazyFrame:
    """Read lab rows matching ``where``, scoped to the cohort's identifiers.

    Args:
        cohort: Cohort object.
        where: SQL boolean condition (may be an OR-chain; it is parenthesised).
        id_column: Identifier used to scope the query.

    Returns:
        pl.LazyFrame: Raw lab rows with ``recordtime`` and ``storetime`` parsed.
    """
    conn = helpers.reinit_conn(cohort)
    dbs = helpers.get_cub_hdp_database(cohort)
    # NOTE(review): the returned LazyFrame is collected after the temporary-table
    # context has exited; this relies on lf_read_sql not deferring the query — confirm.
    with helpers.temporary_table(cohort, id_column=id_column) as temp_tbl:
        # TODO: Use NativeExtractor, which will pull in this info from TABLES
        table_info = {"patient_id": "c_patnr", "case_id": "c_falnr"}
        labs = conn.lf_read_sql(
            LAB_SQL.format(
                db_labs=dbs["prepared"],
                tbl_labs="it_ishmed_labor",
                tmp_filter=helpers.sql_tmp_filter(
                    table_info[id_column], id_column, dbs["workbench"], temp_tbl
                ),
                where=textwrap.indent(where, " " * 8),
            ),
            datetime_cols=["recordtime", "storetime"],
        )
        return labs


def get_lab(
    cohort: Cohort,
    df_mapping: pl.LazyFrame,
    lab_concept: str,
    id_column: Literal["case_id", "patient_id"] = "case_id",
) -> pl.LazyFrame:
    """
    Return all measurements of a lab concept as a clean, unit-unified df.

    Output columns:
        <id_column>, recordtime, value, value_unit, attributes
        (struct of canonical_lab_name, loinc_code, conversion_factor,
        easily_parsable, raw_value, raw_unit, was_censored, relation,
        native_value, storetime, clean_low, clean_high)

    ``value`` is the parsed numeric value times ``conversion_factor``; it is set
    to null when outside [clean_low, clean_high] (missing bounds are open).

    Raises:
        ValueError: If ``lab_concept`` is not in the mapping.
    """
    # Prepare mapping
    df_mapping = df_mapping.with_columns(
        pl.col("c_leistungnr").cast(pl.Utf8, strict=False),
        pl.col("c_katalog_leistungtext").cast(pl.Utf8, strict=False),
        pl.col("c_wert_einheit").cast(pl.Utf8, strict=False),
    )
    map_sub = df_mapping.filter(concept=lab_concept)
    if cast(int, map_sub.select(pl.len()).collect().item()) == 0:
        raise ValueError(f"Lab concept '{lab_concept}' not found in mapping.")

    # One parenthesised AND-condition per mapping row, OR-joined. Values are
    # escaped (backslash first, then single quote) so they cannot terminate or
    # alter the SQL string literal they are embedded in.
    lab_where = cast(
        str,
        map_sub.select(
            pl.format(
                "({})",
                pl.concat_str(
                    (
                        pl.format(
                            "COALESCE(%s, '') = '{}'" % col,
                            pl.coalesce(col, pl.lit(""))
                            .str.replace_all("\\", "\\\\", literal=True)
                            .str.replace_all("'", "\\'", literal=True),
                        )
                        for col in (
                            "c_leistungnr",
                            "c_katalog_leistungtext",
                            "c_wert_einheit",
                        )
                    ),
                    separator=" AND ",
                ),
            )
            .str.join(" OR\n")
            .alias("where")
        )
        .collect()
        .item(),
    )

    map_join = map_sub.select(
        "c_leistungnr",
        "c_katalog_leistungtext",
        "c_wert_einheit",
        "concept",
        "canonical_unit",
        "conversion_factor",
        "loinc_code",
        "easily_parsable",
        "clean_low",
        "clean_high",
    )

    df_labs = get_filtered_labs(cohort, lab_where, id_column)

    joined = df_labs.join(
        other=map_join, on=["c_leistungnr", "c_katalog_leistungtext", "c_wert_einheit"]
    )

    joined = (
        joined.with_columns(
            pl.col("c_wert")
            .str.extract_groups(
                r"^\s*(?<relation>[<>≤≥])?\s*=?\s*(?<native_value>-?\d+(?:[.,]\d+)?)\s*$"
            )
            .struct.unnest()
        )
        .with_columns(
            # The regex accepts a decimal comma ("0,5"); normalise it first,
            # otherwise the strict Float64 cast fails on such values.
            pl.col("native_value")
            .str.replace(",", ".", literal=True)
            .cast(pl.Float64),
            pl.col("relation").is_not_null().alias("was_censored"),
        )
        .with_columns(
            pl.col("native_value").mul(pl.col("conversion_factor")).alias("canonical")
        )
        .with_columns(
            # Null out values outside the mapping's plausibility range.
            pl.when(
                pl.col("canonical").is_between(
                    pl.col("clean_low").fill_null(float("-inf")),
                    pl.col("clean_high").fill_null(float("inf")),
                )
            )
            .then("canonical")
            .otherwise(None)
        )
    )

    return joined.select(
        id_column,
        "recordtime",
        pl.col("canonical").alias("value"),
        pl.col("canonical_unit").alias("value_unit"),
        pl.struct(
            pl.col("concept").alias("canonical_lab_name"),
            "loinc_code",
            "conversion_factor",
            "easily_parsable",
            pl.col("c_wert").alias("raw_value"),
            pl.col("c_wert_einheit").alias("raw_unit"),
            "was_censored",
            "relation",
            "native_value",
            "storetime",
            "clean_low",
            "clean_high",
        ).alias("attributes"),
    )