# Source code for mir.basic.gene_usage

"""Gene usage tracking for immune repertoires.

:class:`GeneUsage` accumulates V-J gene combination counts from
:class:`~mir.common.repertoire.LocusRepertoire` or
:class:`~mir.common.repertoire.SampleRepertoire` objects and exposes joint and
marginal usage statistics together with Laplace-smoothed fractions.

Allele Handling
~~~~~~~~~~~~~~~
By default, gene allele suffixes are stripped during initialization
(e.g., ``TRBV1*01`` → ``TRBV1``) so that different allele naming conventions
are treated as the same gene. This behavior can be disabled by setting
``strip_alleles=False`` when constructing a ``GeneUsage`` object.

When resampling using :func:`mir.common.sampling.resample_to_gene_usage`,
clonotypes retain their original alleles while only stripped gene bases are
used for frequency comparison.
"""

from __future__ import annotations

from collections import defaultdict
import os
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Union

import numpy as np
import pandas as pd
import polars as pl

if TYPE_CHECKING:
    from mir.common.repertoire import LocusRepertoire, SampleRepertoire
    from mir.common.repertoire_dataset import RepertoireDataset

_AnyDataFrame = Union["pd.DataFrame", "pl.DataFrame"]

_VJPair = tuple[str, str]
GeneScope = Literal["v", "j", "vj"]


def _normalize_count_mode(count: str) -> str:
    """Normalize public count-mode aliases.

    Supported modes:
    - ``clonotypes`` / ``count_rearrangement`` (unweighted, default)
    - ``duplicates`` / ``count_duplicates`` (weighted by duplicate_count)
    """
    mode = str(count).strip().lower()
    if mode in {
        "clonotypes",
        "clonotype",
        "rearrangement",
        "rearrangements",
        "count_rearrangement",
        "count_rearrangements",
    }:
        return "clonotypes"
    if mode in {"duplicates", "duplicate", "count_duplicates"}:
        return "duplicates"
    raise ValueError(
        f"Unknown count mode: {count!r}. "
        "Use 'clonotypes'/'count_rearrangement' or "
        "'duplicates'/'count_duplicates'."
    )


def _count_index(count: str) -> int:
    """Return the slot of a ``[n_clones, n_dc]`` pair selected by *count*.

    ``0`` selects the clonotype count, ``1`` the duplicate count. Aliases are
    resolved (and unknown modes rejected) by ``_normalize_count_mode``.
    """
    if _normalize_count_mode(count) == "clonotypes":
        return 0
    return 1


def _laplace_fraction(usage: dict, total: int, pseudocount: float) -> dict:
    """Compute Laplace-smoothed fractions for an observed usage map."""
    n_keys = len(usage)
    denom = total + n_keys * pseudocount
    if denom == 0:
        return {}
    return {k: (n + pseudocount) / denom for k, n in usage.items()}


def _safe_group_renormalize(
    df: pd.DataFrame,
    *,
    value_col: str,
    fallback_col: str,
    group_cols: list[str],
) -> pd.Series:
    """Renormalize ``value_col`` to sum to 1 within each group.

    For groups where the value mass is non-positive or not finite, falls back
    to normalized ``fallback_col``.
    """
    out = df[value_col].to_numpy(dtype=float, copy=True)
    fallback = df[fallback_col].to_numpy(dtype=float, copy=False)
    groups = df.groupby(group_cols, dropna=False, sort=False).indices
    for idx in groups.values():
        idx_arr = np.asarray(list(idx), dtype=int)
        vals = out[idx_arr]
        vals = np.where(np.isfinite(vals), vals, 0.0)
        vals = np.clip(vals, 0.0, None)
        s = float(vals.sum())
        if s > 0.0:
            out[idx_arr] = vals / s
            continue

        fb = np.where(np.isfinite(fallback[idx_arr]), fallback[idx_arr], 0.0)
        fb = np.clip(fb, 0.0, None)
        fb_sum = float(fb.sum())
        if fb_sum > 0.0:
            out[idx_arr] = fb / fb_sum
        else:
            out[idx_arr] = 1.0 / len(idx_arr)
    return pd.Series(out, index=df.index)


def _strip_allele(gene: str) -> str:
    """Strip allele suffix: ``"TRBV1*01"`` → ``"TRBV1"``."""
    return gene.split("*")[0] if gene else ""


def precompute_olga_gene_usage_probabilities(
    *,
    species: str,
    locus: str,
    synthetic_n: int = 10_000_000,
    n_jobs: int | None = None,
    seed: int = 42,
    overwrite: bool = False,
    progress: bool = True,
    control_dir: str | Path | None = None,
    control_manager=None,
    control_kwargs: dict | None = None,
    cache_in_memory: bool = True,
) -> dict[str, dict[object, float]]:
    """Precompute OLGA V/J/VJ usage probabilities for one model.

    Loads (creating it if needed) the synthetic OLGA control for the requested
    ``(species, locus, synthetic_n)`` through the control manager and derives
    usage probabilities from that control table.

    Args:
        species: Species alias accepted by
            :class:`~mir.common.control.ControlManager`.
        locus: IMGT locus code (for example, ``"TRB"``).
        synthetic_n: Number of synthetic clonotypes used to estimate usage.
        n_jobs: Number of worker processes for synthetic generation. When
            ``None``, ``os.cpu_count()`` is used (at least 1).
        seed: Random seed used when creating synthetic controls.
        overwrite: Regenerate the control even if a cached artifact exists;
            also bypasses the in-process cache lookup.
        progress: Whether to print progress during control generation.
        control_dir: Optional control cache root, used only when
            ``control_manager`` is not supplied.
        control_manager: Optional preconfigured
            :class:`~mir.common.control.ControlManager`.
        control_kwargs: Extra kwargs forwarded to
            ``ControlManager.ensure_and_load_control_df``. Keys given here take
            precedence over the values derived from the arguments above.
        cache_in_memory: Whether to read from and populate the in-process
            OLGA usage cache.

    Returns:
        Dict with keys ``"v"``, ``"j"``, ``"vj"`` containing probability maps.
    """
    from mir.basic.pgen import (
        _GENE_USAGE_PROB_CACHE,
        compute_gene_usage_probabilities_from_control_df,
    )
    from mir.common.control import ControlManager

    # The cache is keyed on (species, locus, n) only.
    cache_key = (str(species).lower().strip(), str(locus), int(synthetic_n))
    use_cached = cache_in_memory and not overwrite
    if use_cached and cache_key in _GENE_USAGE_PROB_CACHE:
        return _GENE_USAGE_PROB_CACHE[cache_key]

    manager = control_manager or ControlManager(control_dir=control_dir)

    if n_jobs is not None:
        worker_count = int(n_jobs)
    else:
        worker_count = int(os.cpu_count() or 1)

    # Caller-supplied control_kwargs win; the defaults only fill the gaps.
    call_kwargs = dict(control_kwargs or {})
    derived_defaults = {
        "n": int(synthetic_n),
        "n_jobs": max(1, worker_count),
        "seed": int(seed),
        "overwrite": bool(overwrite),
        "progress": bool(progress),
    }
    for option, value in derived_defaults.items():
        call_kwargs.setdefault(option, value)

    control_df = manager.ensure_and_load_control_df(
        "synthetic",
        species,
        locus,
        **call_kwargs,
    )
    probabilities = compute_gene_usage_probabilities_from_control_df(control_df)
    if cache_in_memory:
        _GENE_USAGE_PROB_CACHE[cache_key] = probabilities
    return probabilities
def get_gene_usage_from_olga_model(
    model,
) -> dict[str, dict[object, float]]:
    """Return V/J/VJ usage probabilities read from OLGA model marginals.

    This is an analytical alternative to
    :func:`precompute_olga_gene_usage_probabilities`: instead of generating
    synthetic sequences, it reads the probability arrays stored on the loaded
    model, so the result is deterministic.

    Probabilities are aggregated at the gene level (allele suffixes stripped),
    e.g. ``"TRBV5-1*01"`` and ``"TRBV5-1*02"`` are summed under ``"TRBV5-1"``.

    Args:
        model: A loaded :class:`~mir.basic.pgen.OlgaModel` instance. The
            attributes read are ``gen_model``, ``v_names``, ``j_names`` and
            ``is_d_present``.

    Returns:
        Dict with keys ``"v"``, ``"j"``, ``"vj"`` mapping to probability dicts.
        ``"vj"`` keys are ``(v_gene, j_gene)`` tuples.

    Example:
        >>> from mir.basic.pgen import OlgaModel
        >>> from mir.basic.gene_usage import get_gene_usage_from_olga_model
        >>> m = OlgaModel(locus="TRB", species="human")
        >>> gu = get_gene_usage_from_olga_model(m)
        >>> round(sum(gu["v"].values()), 6)
        1.0
    """
    from mir.common.alleles import strip_allele

    generative = model.gen_model
    v_names: list[str] = model.v_names
    j_names: list[str] = model.j_names

    if model.is_d_present:
        # VDJ model: V is treated as independent of (D, J); P(J) = sum_D P(D, J).
        v_probs = np.asarray(generative.PV, dtype=float)
        j_probs = np.asarray(generative.PDJ, dtype=float).sum(axis=0)
        joint_probs = v_probs[:, None] * j_probs[None, :]
    else:
        # VJ model: the joint P(V, J) is stored directly as a 2-D array.
        joint_probs = np.asarray(generative.PVJ, dtype=float)
        v_probs = joint_probs.sum(axis=1)
        j_probs = joint_probs.sum(axis=0)

    # Strip alleles so keys match the GeneUsage ``strip_alleles=True``
    # convention; alleles of one gene collapse onto a single key.
    v_keys = [strip_allele(name) for name in v_names]
    j_keys = [strip_allele(name) for name in j_names]

    v_usage: dict[str, float] = defaultdict(float)
    for key, prob in zip(v_keys, v_probs):
        v_usage[key] += float(prob)

    j_usage: dict[str, float] = defaultdict(float)
    for key, prob in zip(j_keys, j_probs):
        j_usage[key] += float(prob)

    vj_usage: dict[tuple, float] = defaultdict(float)
    for row, v_key in enumerate(v_keys):
        for col, j_key in enumerate(j_keys):
            vj_usage[(v_key, j_key)] += float(joint_probs[row, col])

    return {"v": dict(v_usage), "j": dict(j_usage), "vj": dict(vj_usage)}
class GeneUsage:
    """Joint and marginal V-J gene usage statistics.

    Stores per-locus clonotype counts and duplicate-count totals for every
    observed (V-gene, J-gene) pair.

    Parameters
    ----------
    strip_alleles : bool, optional
        When ``True`` (default), remove allele suffixes while counting so that
        ``TRBV1*01`` and ``TRBV1`` are treated as the same gene. When
        ``False``, alleles are preserved as-is.

    Attributes
    ----------
    strip_alleles : bool
        Whether allele suffixes are stripped while counting.

    Examples
    --------
    Build from a repertoire, automatically stripping alleles::

        gu = GeneUsage.from_repertoire(trb_repertoire)
        gu.vj_fraction("TRB")
        {('TRBV12-3', 'TRBJ1-2'): 0.42, ...}

    Build with alleles preserved::

        gu = GeneUsage.from_repertoire(trb_repertoire, strip_alleles=False)
        gu.vj_fraction("TRB")
        {('TRBV12-3*01', 'TRBJ1-2*01'): 0.42, ...}
    """

    def __init__(self, *, strip_alleles: bool = True) -> None:
        # locus → {(v_base, j_base): [n_clones, n_dc]}
        self._data: dict[str, dict[_VJPair, list[int]]] = {}
        # locus → [total_clones, total_dc]
        self._totals: dict[str, list[int]] = {}
        self.strip_alleles = strip_alleles

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @classmethod
    def from_repertoire(
        cls,
        repertoire: "LocusRepertoire",
        *,
        locus: str = "",
        strip_alleles: bool = True,
    ) -> "GeneUsage":
        """Build from a :class:`~mir.common.repertoire.LocusRepertoire`.

        Parameters
        ----------
        repertoire
            Source locus repertoire.
        locus
            Override locus. When empty the repertoire's own locus is used.
        strip_alleles
            Whether to strip allele suffixes (default ``True``).
        """
        obj = cls(strip_alleles=strip_alleles)
        obj._add_locus_repertoire(repertoire, locus=locus)
        return obj

    @classmethod
    def from_sample(
        cls,
        sample: "SampleRepertoire",
        *,
        strip_alleles: bool = True,
    ) -> "GeneUsage":
        """Build from a :class:`~mir.common.repertoire.SampleRepertoire`.

        Iterates over all loci in ``sample.loci``.

        Parameters
        ----------
        sample
            Source sample repertoire.
        strip_alleles
            Whether to strip allele suffixes (default ``True``).
        """
        obj = cls(strip_alleles=strip_alleles)
        for loc, locus_rep in sample.loci.items():
            obj._add_locus_repertoire(locus_rep, locus=loc)
        return obj

    @classmethod
    def from_list(
        cls,
        repertoires,
        *,
        strip_alleles: bool = True,
    ) -> "GeneUsage":
        """Build by accumulating data from a list of repertoire objects.

        Each element may be a :class:`~mir.common.repertoire.LocusRepertoire`
        or a :class:`~mir.common.repertoire.SampleRepertoire`; anything that is
        not a ``SampleRepertoire`` is handled as a single-locus repertoire.

        Parameters
        ----------
        repertoires
            List of LocusRepertoire or SampleRepertoire objects.
        strip_alleles
            Whether to strip allele suffixes (default ``True``).
        """
        from mir.common.repertoire import SampleRepertoire

        obj = cls(strip_alleles=strip_alleles)
        for rep in repertoires:
            if isinstance(rep, SampleRepertoire):
                for loc, locus_rep in rep.loci.items():
                    obj._add_locus_repertoire(locus_rep, locus=loc)
            else:
                obj._add_locus_repertoire(rep)
        return obj

    @classmethod
    def from_dataframe(
        cls,
        table: "_AnyDataFrame",
        *,
        locus: str,
        v_col: str = "v_gene",
        j_col: str = "j_gene",
        duplicate_count_col: str = "duplicate_count",
        strip_alleles: bool = True,
    ) -> "GeneUsage":
        """Build from a DataFrame with V/J columns.

        Rows whose V or J gene is null or empty are ignored.

        Parameters
        ----------
        table
            Input table (pandas or polars) containing V/J gene fields and
            optionally duplicate counts.
        locus
            IMGT locus code to assign to all rows.
        v_col, j_col
            Column names for V/J genes.
        duplicate_count_col
            Column name for duplicate count. If the column is absent, or a
            value is null / not castable to an integer, the row counts as 1.
        strip_alleles
            Whether to strip allele suffixes (default ``True``).

        Raises
        ------
        ValueError
            If ``v_col`` or ``j_col`` is missing from the table.
        """
        # Normalise to polars for uniform processing.
        tbl = table if isinstance(table, pl.DataFrame) else pl.from_pandas(table)

        if v_col not in tbl.columns or j_col not in tbl.columns:
            raise ValueError(
                f"DataFrame must contain columns {v_col!r} and {j_col!r}"
            )

        obj = cls(strip_alleles=strip_alleles)
        locus_data = obj._data.setdefault(locus, {})
        locus_totals = obj._totals.setdefault(locus, [0, 0])

        cols = [v_col, j_col]
        if duplicate_count_col in tbl.columns:
            tbl = tbl.select(cols + [duplicate_count_col]).with_columns(
                pl.col(duplicate_count_col).cast(pl.Int64, strict=False).fill_null(1)
            )
        else:
            tbl = tbl.select(cols).with_columns(pl.lit(1).alias(duplicate_count_col))

        tbl = (
            tbl
            .with_columns([
                pl.col(v_col).cast(pl.Utf8).fill_null(""),
                pl.col(j_col).cast(pl.Utf8).fill_null(""),
            ])
            .filter((pl.col(v_col) != "") & (pl.col(j_col) != ""))
        )
        if tbl.is_empty():
            return obj

        grouped = (
            tbl.group_by([v_col, j_col])
            .agg([
                pl.len().alias("n_clones"),
                pl.col(duplicate_count_col).sum().alias("n_dc"),
            ])
        )

        # Grouping ran on raw gene names; pairs that collapse onto the same
        # normalized key are merged by the accumulation below.
        for row in grouped.iter_rows(named=True):
            v = obj._normalize_gene(str(row[v_col] or ""))
            j = obj._normalize_gene(str(row[j_col] or ""))
            cls._add_counts(
                locus_data,
                locus_totals,
                (v, j),
                int(row["n_clones"]),
                int(row["n_dc"]),
            )
        return obj

    @staticmethod
    def _add_counts(
        locus_data: dict,
        locus_totals: list,
        pair: tuple,
        n_clones: int,
        n_dc: int,
    ) -> None:
        """Add pre-aggregated counts for one V-J *pair* and the locus totals."""
        entry = locus_data.setdefault(pair, [0, 0])
        entry[0] += n_clones
        entry[1] += n_dc
        locus_totals[0] += n_clones
        locus_totals[1] += n_dc

    @staticmethod
    def _aggregate_polars_table(table):
        """Aggregate a polars table into ``(v, j, n_clones, n_dc)`` tuples.

        The table is expected to expose ``v_gene``, ``j_gene`` and
        ``duplicate_count`` columns. Returns ``None`` when anything in the
        polars path fails, so the caller can fall back to another path. No
        ``GeneUsage`` state is touched here, which keeps a failure from
        leaving partially accumulated counts behind.
        """
        try:
            import polars as pl

            if table.height == 0:
                return []
            grouped = (
                table
                .select([
                    pl.col("v_gene").cast(pl.Utf8).fill_null(""),
                    pl.col("j_gene").cast(pl.Utf8).fill_null(""),
                    pl.col("duplicate_count").cast(pl.Int64).fill_null(0),
                ])
                .group_by(["v_gene", "j_gene"])
                .agg([
                    pl.len().alias("n_clones"),
                    pl.col("duplicate_count").sum().alias("n_dc"),
                ])
            )
            return [
                (
                    str(row.get("v_gene") or ""),
                    str(row.get("j_gene") or ""),
                    int(row.get("n_clones") or 0),
                    int(row.get("n_dc") or 0),
                )
                for row in grouped.iter_rows(named=True)
            ]
        except Exception:
            # Best-effort fast path: any failure means "use the generic path".
            return None

    def _add_locus_repertoire(self, repertoire, *, locus: str = "") -> None:
        """Accumulate V-J counts of one locus repertoire into this object.

        Three sources are tried in order: a ``_polars_table`` attribute
        (aggregated with polars), a ``_pending_cols`` mapping of raw columns,
        and finally iteration over ``repertoire.clonotypes``. The locus is
        registered even when the repertoire contributes no rows.

        Parameters
        ----------
        repertoire
            Object exposing ``locus`` and at least one of the sources above.
        locus
            Override locus. When empty the repertoire's own locus is used.
        """
        loc = locus or repertoire.locus or ""
        locus_data = self._data.setdefault(loc, {})
        locus_totals = self._totals.setdefault(loc, [0, 0])

        table = getattr(repertoire, "_polars_table", None)
        if table is not None:
            aggregated = self._aggregate_polars_table(table)
            if aggregated is not None:
                # Counts are only mutated once aggregation fully succeeded, so
                # a polars failure can never be followed by double counting.
                for v_raw, j_raw, n_clones, n_dc in aggregated:
                    pair = (self._normalize_gene(v_raw), self._normalize_gene(j_raw))
                    self._add_counts(locus_data, locus_totals, pair, n_clones, n_dc)
                return

        # Fast path for lazily loaded repertoires: consume raw columns directly
        # and avoid constructing per-clonotype Python objects. The accumulation
        # is kept inline here (and below) because these loops run once per
        # clonotype.
        pending = getattr(repertoire, "_pending_cols", None)
        if pending is not None:
            v_genes = pending.get("v_genes", [])
            j_genes = pending.get("j_genes", [])
            dups = pending.get("dup_counts", [])
            for v_gene, j_gene, dc in zip(v_genes, j_genes, dups):
                v = self._normalize_gene(v_gene or "")
                j = self._normalize_gene(j_gene or "")
                dc_i = int(dc or 0)
                entry = locus_data.setdefault((v, j), [0, 0])
                entry[0] += 1
                entry[1] += dc_i
                locus_totals[0] += 1
                locus_totals[1] += dc_i
            return

        for clone in repertoire.clonotypes:
            v = self._normalize_gene(clone.v_gene or "")
            j = self._normalize_gene(clone.j_gene or "")
            dc = clone.duplicate_count or 0
            entry = locus_data.setdefault((v, j), [0, 0])
            entry[0] += 1
            entry[1] += dc
            locus_totals[0] += 1
            locus_totals[1] += dc

    def _normalize_gene(self, gene: str) -> str:
        """Apply gene normalization based on strip_alleles setting."""
        return _strip_allele(gene) if self.strip_alleles else gene

    # ------------------------------------------------------------------
    # Loci
    # ------------------------------------------------------------------

    @property
    def loci(self) -> list[str]:
        """Loci registered on this object, in insertion order."""
        return list(self._data.keys())

    # ------------------------------------------------------------------
    # Totals
    # ------------------------------------------------------------------

    def total(self, locus: str, *, count: str = "clonotypes") -> int:
        """Total count for *locus* (``0`` for an unknown locus).

        Args:
            locus: IMGT locus code.
            count: ``"clonotypes"`` (unique rearrangements) or ``"duplicates"``.
        """
        return self._totals.get(locus, [0, 0])[_count_index(count)]

    # ------------------------------------------------------------------
    # Usage accessors
    # ------------------------------------------------------------------

    def vj_usage(
        self,
        locus: str,
        *,
        count: str = "clonotypes",
    ) -> dict[_VJPair, int]:
        """Joint V-J usage for *locus*.

        Args:
            locus: IMGT locus code.
            count: ``"clonotypes"`` or ``"duplicates"``.

        Returns:
            Dict mapping ``(v_base, j_base)`` to the requested count.
        """
        idx = _count_index(count)
        return {pair: vals[idx] for pair, vals in self._data.get(locus, {}).items()}

    def _marginal_usage(self, locus: str, *, count: str, axis: int) -> dict[str, int]:
        """Generic V/J marginal usage helper.

        Parameters
        ----------
        axis
            0 for V-gene aggregation, 1 for J-gene aggregation.
        """
        idx = _count_index(count)
        result: dict[str, int] = defaultdict(int)
        for pair, vals in self._data.get(locus, {}).items():
            result[pair[axis]] += vals[idx]
        return dict(result)

    def v_usage(
        self,
        locus: str,
        *,
        count: str = "clonotypes",
    ) -> dict[str, int]:
        """Marginal V-gene usage (sum over all J) for *locus*."""
        return self._marginal_usage(locus, count=count, axis=0)

    def j_usage(
        self,
        locus: str,
        *,
        count: str = "clonotypes",
    ) -> dict[str, int]:
        """Marginal J-gene usage (sum over all V) for *locus*."""
        return self._marginal_usage(locus, count=count, axis=1)

    # ------------------------------------------------------------------
    # Fractions with Laplace smoothing
    # ------------------------------------------------------------------

    def vj_fraction(
        self,
        locus: str,
        *,
        count: str = "clonotypes",
        pseudocount: float = 1.0,
    ) -> dict[_VJPair, float]:
        """Laplace-smoothed V-J fraction for *locus*.

        Fractions sum to 1 over observed pairs using::

            (n_observed + pseudocount) / (total + n_observed_pairs * pseudocount)

        Args:
            locus: IMGT locus code.
            count: ``"clonotypes"`` or ``"duplicates"``.
            pseudocount: Added to each count and the denominator term.
        """
        usage = self.vj_usage(locus, count=count)
        return _laplace_fraction(usage, self.total(locus, count=count), pseudocount)

    def v_fraction(
        self,
        locus: str,
        *,
        count: str = "clonotypes",
        pseudocount: float = 1.0,
    ) -> dict[str, float]:
        """Laplace-smoothed marginal V-gene fraction for *locus*."""
        usage = self.v_usage(locus, count=count)
        return _laplace_fraction(usage, self.total(locus, count=count), pseudocount)

    def j_fraction(
        self,
        locus: str,
        *,
        count: str = "clonotypes",
        pseudocount: float = 1.0,
    ) -> dict[str, float]:
        """Laplace-smoothed marginal J-gene fraction for *locus*."""
        usage = self.j_usage(locus, count=count)
        return _laplace_fraction(usage, self.total(locus, count=count), pseudocount)

    def _usage_by_scope(self, locus: str, *, scope: str, count: str) -> dict:
        """Dispatch helper for v/j/vj usage maps.

        ``scope`` is matched case-insensitively after stripping whitespace;
        anything other than ``v``, ``j`` or ``vj`` raises ``ValueError``.
        """
        scope_norm = str(scope).strip().lower()
        if scope_norm == "v":
            return self.v_usage(locus, count=count)
        if scope_norm == "j":
            return self.j_usage(locus, count=count)
        if scope_norm == "vj":
            return self.vj_usage(locus, count=count)
        raise ValueError("scope must be one of: 'v', 'j', 'vj'")

    # ------------------------------------------------------------------
    # Cross-dataset comparison helpers
    # ------------------------------------------------------------------

    def usage_comparison(
        self,
        reference: "GeneUsage",
        locus: str,
        *,
        scope: str = "vj",
        count: str = "count_rearrangement",
        pseudocount: float = 1.0,
    ) -> dict[object, dict[str, float]]:
        """Compare smoothed usage frequencies against another GeneUsage.

        Frequencies are computed separately for ``self`` and ``reference``
        over the union of keys seen in either object, with the same
        pseudocount:
        ``(n_key + pseudocount) / (total + n_union_keys * pseudocount)``.

        Args:
            reference: Baseline gene usage to compare against (e.g. OLGA).
            locus: IMGT locus code.
            scope: ``"v"``, ``"j"``, or ``"vj"``.
            count: Count mode alias (default ``count_rearrangement``).
            pseudocount: Additive smoothing constant (must be >= 0).

        Returns:
            Mapping from key (gene or VJ tuple) to:
            ``{"p_self": ..., "p_reference": ..., "factor": ...}``.
            ``factor`` is ``inf`` whenever ``p_reference`` is zero.

        Raises:
            ValueError: If ``pseudocount`` is negative.
        """
        if pseudocount < 0:
            raise ValueError("pseudocount must be non-negative")

        self_usage = self._usage_by_scope(locus, scope=scope, count=count)
        ref_usage = reference._usage_by_scope(locus, scope=scope, count=count)
        all_keys = sorted(set(self_usage) | set(ref_usage))

        self_total = self.total(locus, count=count)
        ref_total = reference.total(locus, count=count)
        n_all = len(all_keys)

        self_denom = self_total + n_all * pseudocount
        ref_denom = ref_total + n_all * pseudocount

        result: dict[object, dict[str, float]] = {}
        for key in all_keys:
            p_self = (
                (self_usage.get(key, 0) + pseudocount) / self_denom
                if self_denom > 0
                else 0.0
            )
            p_ref = (
                (ref_usage.get(key, 0) + pseudocount) / ref_denom
                if ref_denom > 0
                else 0.0
            )
            factor = (p_self / p_ref) if p_ref > 0 else float("inf")
            result[key] = {
                "p_self": float(p_self),
                "p_reference": float(p_ref),
                "factor": float(factor),
            }
        return result

    def correction_factors(
        self,
        reference: "GeneUsage",
        locus: str,
        *,
        scope: str = "vj",
        count: str = "count_rearrangement",
        pseudocount: float = 1.0,
    ) -> dict[object, float]:
        """Return correction factors ``P_self / P_reference`` by key."""
        comparison = self.usage_comparison(
            reference,
            locus,
            scope=scope,
            count=count,
            pseudocount=pseudocount,
        )
        return {k: v["factor"] for k, v in comparison.items()}
# ------------------------------------------------------------------ # Batch correction utilities # ------------------------------------------------------------------
def zscore_to_sigmoid(z: "np.ndarray | float") -> "np.ndarray | float":
    """Map a (batch-corrected) z-score to a bounded sigmoid value.

    ``sigmoid(z) = 1 / (1 + exp(-z))``

    The logistic function is evaluated in a sign-split form so that large
    negative inputs do not overflow ``exp`` (and therefore do not emit a
    ``RuntimeWarning``): for ``z >= 0`` the textbook formula is used, for
    ``z < 0`` the algebraically identical ``exp(z) / (1 + exp(z))``.

    Parameters
    ----------
    z
        Scalar or array of z-scores.

    Returns
    -------
    np.ndarray or float
        Same shape as *z*. Values lie in ``[0, 1]``; they are mathematically
        inside ``(0, 1)`` but saturate to the bounds in floating point for
        very large ``|z|``. ``NaN`` inputs yield ``NaN``.
    """
    arr = np.asarray(z, dtype=float)
    # Work on an at-least-1-D view so boolean masking also covers scalars.
    values = np.atleast_1d(arr)
    result = np.empty_like(values)

    non_negative = values >= 0  # NaN compares False and is handled below
    result[non_negative] = 1.0 / (1.0 + np.exp(-values[non_negative]))

    exp_neg = np.exp(values[~non_negative])  # argument < 0 (or NaN): no overflow
    result[~non_negative] = exp_neg / (1.0 + exp_neg)

    return float(result[0]) if arr.ndim == 0 else result
def _winsorized_mean_std(values, *, lower_q: float = 0.025, upper_q: float = 0.975) -> tuple[float, float]: """Return mean and SD after clipping to the winsorized interval.""" arr = np.asarray(values, dtype=float) arr = arr[np.isfinite(arr)] if arr.size == 0: return 0.0, 0.0 lo = float(np.quantile(arr, lower_q)) hi = float(np.quantile(arr, upper_q)) clipped = np.clip(arr, lo, hi) mean = float(np.mean(clipped)) std = float(np.std(clipped, ddof=1)) if clipped.size > 1 else 0.0 if not np.isfinite(std): std = 0.0 return mean, std
def compute_batch_corrected_gene_usage(
    dataset: "RepertoireDataset",
    *,
    batch_field: str = "batch_id",
    scope: GeneScope = "vj",
    weighted: bool = True,
    pseudocount: float = 1.0,
    z_cap: float = 6.0,
) -> pd.DataFrame:
    """Compute batch-corrected gene usage for all samples/loci/genes.

    Uses a pseudocount on raw counts prior to normalization:

    ``p = (count + pseudocount) / (total + pseudocount * n_genes)``

    Then computes ``log_p``, batch-wise winsorized (95%) ``mu`` and ``sigma``
    over ``(locus, gene, batch_id)``, capped z-scores, and final corrected
    probabilities:

    ``correction_factor = exp(z)``
    ``pfinal_raw = p * correction_factor``

    Finally, for each ``(sample_id, locus)`` group ``pfinal`` is renormalized
    so probabilities sum to 1. If a group's raw corrected mass is invalid or
    non-positive, normalized raw ``p`` is used for that group instead.

    Empty sample loci and loci absent in a sample are skipped without error.

    Parameters
    ----------
    dataset
        Dataset exposing ``samples`` (sample id → sample with a ``loci``
        mapping) and ``metadata`` (sample id → dict).
    batch_field
        Metadata key holding the batch identifier of each sample.
    scope
        ``"v"``, ``"j"`` or ``"vj"``.
    weighted
        When ``True`` counts are duplicate-weighted and ``total`` is read from
        the locus repertoire's ``duplicate_count``; otherwise clonotypes are
        counted and ``total`` is its ``clonotype_count``.
    pseudocount
        Additive smoothing constant (must be >= 0).
    z_cap
        Absolute cap applied to z-scores (must be > 0).

    Returns
    -------
    pd.DataFrame
        One row per ``(sample_id, locus, gene)`` with columns ``sample_id``,
        ``batch_id``, ``locus``, ``gene``, ``count``, ``total``, ``n_genes``,
        ``p``, ``log_p``, ``mu``, ``sigma``, ``z``, ``pavg``, ``pfinal``,
        sorted by ``sample_id``, ``locus``, ``gene``. An empty frame with
        these columns is returned when there is nothing to report.

    Raises
    ------
    ValueError
        If ``pseudocount`` or ``z_cap`` is out of range, or a sample's
        metadata lacks ``batch_field``.
    """
    if pseudocount < 0:
        raise ValueError("pseudocount must be >= 0")
    if z_cap <= 0:
        raise ValueError("z_cap must be > 0")

    count_mode = "duplicates" if weighted else "clonotypes"

    # Pass 1: per-sample usage maps and the union of genes seen per locus.
    sample_usage: dict[tuple[str, str], dict[object, int]] = {}
    genes_by_locus: dict[str, set[object]] = defaultdict(set)
    for sample_id, sample in dataset.samples.items():
        gu = GeneUsage.from_sample(sample)
        for locus, locus_rep in sample.loci.items():
            if locus_rep is None or getattr(locus_rep, "clonotype_count", 0) == 0:
                continue
            usage = gu._usage_by_scope(locus, scope=scope, count=count_mode)
            sample_usage[(sample_id, locus)] = usage
            genes_by_locus[locus].update(usage.keys())

    columns = [
        "sample_id",
        "batch_id",
        "locus",
        "gene",
        "count",
        "total",
        "n_genes",
        "p",
        "log_p",
        "mu",
        "sigma",
        "z",
        "pavg",
        "pfinal",
    ]
    if not sample_usage:
        return pd.DataFrame(columns=columns)

    # Pooled (all samples together) usage fraction per (locus, gene).
    pooled_counts: dict[tuple[str, object], float] = defaultdict(float)
    pooled_totals: dict[str, float] = defaultdict(float)
    for (_, locus), usage in sample_usage.items():
        for gene, val in usage.items():
            pooled_counts[(locus, gene)] += float(val)
            pooled_totals[locus] += float(val)

    pavg: dict[tuple[str, object], float] = {}
    for locus, genes in genes_by_locus.items():
        denom = float(pooled_totals.get(locus, 0.0))
        if denom <= 0:
            for gene in genes:
                pavg[(locus, gene)] = 0.0
            continue
        for gene in genes:
            pavg[(locus, gene)] = pooled_counts[(locus, gene)] / denom

    # The gene order per locus does not depend on the sample: sort it once.
    sorted_genes = {locus: sorted(genes) for locus, genes in genes_by_locus.items()}

    # Pass 2: one row per (sample, locus, gene), including unobserved genes.
    rows: list[dict[str, object]] = []
    for sample_id, sample in dataset.samples.items():
        metadata = dataset.metadata.get(sample_id, {})
        if batch_field not in metadata:
            raise ValueError(f"metadata for sample_id={sample_id!r} missing required field {batch_field!r}")
        batch_id = metadata[batch_field]

        for locus, locus_rep in sample.loci.items():
            if locus not in genes_by_locus:
                continue
            if locus_rep is None or getattr(locus_rep, "clonotype_count", 0) == 0:
                continue

            usage = sample_usage.get((sample_id, locus), {})
            n_genes = len(genes_by_locus[locus])
            if weighted:
                total = float(getattr(locus_rep, "duplicate_count", 0))
            else:
                total = float(getattr(locus_rep, "clonotype_count", 0))
            denom = total + pseudocount * n_genes

            for gene in sorted_genes[locus]:
                count = float(usage.get(gene, 0.0))
                p = ((count + pseudocount) / denom) if denom > 0 else 0.0
                log_p = float(np.log(p)) if p > 0 else float("-inf")
                rows.append(
                    {
                        "sample_id": sample_id,
                        "batch_id": batch_id,
                        "locus": locus,
                        "gene": gene,
                        "count": count,
                        "total": total,
                        "n_genes": n_genes,
                        "p": p,
                        "log_p": log_p,
                    }
                )

    df = pd.DataFrame(rows)
    if df.empty:
        return pd.DataFrame(columns=columns)

    stats = (
        df.groupby(["locus", "gene", "batch_id"], dropna=False)["log_p"]
        .apply(_winsorized_mean_std)
        .reset_index(name="mu_sigma")
    )
    stats[["mu", "sigma"]] = pd.DataFrame(stats["mu_sigma"].tolist(), index=stats.index)
    stats = stats.drop(columns=["mu_sigma"])

    df = df.merge(stats, on=["locus", "gene", "batch_id"], how="left")
    df["sigma"] = pd.to_numeric(df["sigma"], errors="coerce").fillna(0.0)
    df["mu"] = pd.to_numeric(df["mu"], errors="coerce").fillna(0.0)

    sigma = df["sigma"].to_numpy(dtype=float)
    mu = df["mu"].to_numpy(dtype=float)
    log_p_values = df["log_p"].to_numpy(dtype=float)
    # Divide only where sigma > 0; elsewhere z stays 0. Evaluating the
    # quotient everywhere would raise divide-by-zero / invalid-value warnings
    # for every zero-variance (locus, gene, batch) group.
    raw_z = np.zeros_like(log_p_values)
    np.divide(log_p_values - mu, sigma, out=raw_z, where=sigma > 0)
    df["z"] = np.clip(raw_z, -z_cap, z_cap)

    df["pavg"] = [pavg[key] for key in zip(df["locus"], df["gene"])]
    correction_factor = np.exp(df["z"].to_numpy(dtype=float))
    df["pfinal"] = df["p"].to_numpy(dtype=float) * correction_factor
    df["pfinal"] = _safe_group_renormalize(
        df,
        value_col="pfinal",
        fallback_col="p",
        group_cols=["sample_id", "locus"],
    )

    return df[columns].sort_values(["sample_id", "locus", "gene"]).reset_index(drop=True)
def _extract_marginal_gene(gene: object, *, axis: int) -> str: """Return V or J component from a VJ key object. Parameters ---------- gene VJ key from ``compute_batch_corrected_gene_usage(..., scope='vj')``. Expected shape is ``(v_gene, j_gene)``. axis ``0`` for V-gene, ``1`` for J-gene. """ if isinstance(gene, tuple) and len(gene) > axis: return str(gene[axis]) if isinstance(gene, list) and len(gene) > axis: return str(gene[axis]) raise ValueError( "Expected VJ tuple/list genes in input DataFrame. " "Run compute_batch_corrected_gene_usage(..., scope='vj') first." )
def marginalize_batch_corrected_gene_usage(
    df: pd.DataFrame,
    *,
    scope: Literal["v", "j"],
) -> pd.DataFrame:
    """Collapse batch-corrected VJ usage onto its V or J marginal.

    Takes the output of :func:`compute_batch_corrected_gene_usage` computed
    with ``scope='vj'`` and sums ``p``, ``pfinal`` and ``pavg`` over the
    opposite gene. ``p`` and ``pfinal`` are then renormalized within each
    ``(sample_id, locus)``; ``pavg`` is averaged across samples per
    ``(locus, gene)`` and renormalized within each locus.

    Parameters
    ----------
    df
        DataFrame from ``compute_batch_corrected_gene_usage(..., scope='vj')``.
        Required columns: ``sample_id``, ``batch_id``, ``locus``, ``gene``,
        ``p``, ``pfinal``, ``pavg``.
    scope
        Target marginal scope: ``"v"`` (sum over J) or ``"j"`` (sum over V).

    Returns
    -------
    pd.DataFrame
        Columns: ``sample_id``, ``batch_id``, ``locus``, ``gene``, ``p``,
        ``pfinal``, ``pavg``, sorted by ``sample_id``, ``locus``, ``gene``.

    Raises
    ------
    ValueError
        If a required column is missing, ``scope`` is not ``"v"``/``"j"``, or
        a ``gene`` entry is not a VJ tuple/list.
    """
    key_cols = ["sample_id", "batch_id", "locus", "gene"]
    value_cols = ["p", "pfinal", "pavg"]
    output_cols = key_cols + value_cols

    missing = set(output_cols).difference(df.columns)
    if missing:
        raise ValueError(
            f"Input DataFrame missing required columns: {sorted(missing)}"
        )

    scope_norm = str(scope).strip().lower()
    if scope_norm not in {"v", "j"}:
        raise ValueError("scope must be one of: 'v', 'j'")
    axis = 0 if scope_norm == "v" else 1

    def _to_marginal(vj_key: object) -> str:
        return _extract_marginal_gene(vj_key, axis=axis)

    working = df.loc[:, output_cols].copy()
    working["gene"] = working["gene"].map(_to_marginal)

    # Sum the joint mass of every VJ pair sharing the same marginal gene.
    grouped = working.groupby(key_cols, as_index=False, sort=True)
    out = grouped[value_cols].sum()

    for column in ("p", "pfinal"):
        out[column] = _safe_group_renormalize(
            out,
            value_col=column,
            fallback_col="p",
            group_cols=["sample_id", "locus"],
        )

    # pavg is a per-(locus, gene) reference value: average the per-sample sums
    # and renormalize inside each locus before joining it back.
    pavg_ref = out.groupby(["locus", "gene"], as_index=False, sort=True)["pavg"].mean()
    pavg_ref["pavg"] = _safe_group_renormalize(
        pavg_ref,
        value_col="pavg",
        fallback_col="pavg",
        group_cols=["locus"],
    )
    out = out.drop(columns=["pavg"]).merge(
        pavg_ref,
        on=["locus", "gene"],
        how="left",
    )

    ordered = out[output_cols].sort_values(["sample_id", "locus", "gene"])
    return ordered.reset_index(drop=True)