TCREmp VDJdb Analysis (mirpy)

This notebook reproduces the core tcremp_vdjdb workflow using mirpy APIs and AIRR benchmark assets from Hugging Face.

Pipeline:

  1. Load VDJdb slim from isalgo/airr_benchmark.

  2. Filter to paired, non-10x human records and build TRA and TRB clonotypes labelled by antigen.epitope.

  3. Embed with TCREmp.

  4. Evaluate representation with PCA/UMAP, DBSCAN, and classifiers.

  5. Compare fixed_gap vs biopython backends (runtime + predictive quality).

[1]:
# Setup: imports, deterministic seed, plotting style, and package versions for reproducibility.
import importlib
import importlib.metadata as _meta
import subprocess
import sys
import time
import warnings
from pathlib import Path


def _ensure(pkg, import_name=None):
    """pip-install `pkg` if its module (`import_name`, default `pkg`) cannot be imported."""
    try:
        importlib.import_module(import_name or pkg)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg, "-q"])
        importlib.invalidate_caches()  # make the freshly installed package importable


_ensure("kneed")
_ensure("umap-learn", "umap")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    silhouette_score,
)
import umap
from kneed import KneeLocator

# Make the repository importable when the notebook runs from notebooks/.
repo_root = Path.cwd().resolve().parent if Path.cwd().name == "notebooks" else Path.cwd().resolve()
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

from mir.common.clonotype import Clonotype
from mir.embedding.tcremp import TCREmp
from mir.utils.notebook_assets import ensure_airr_benchmark, find_airr_benchmark_vdjdb_slim

SEED = 42
np.random.seed(SEED)

# Publication-ready matplotlib style
plt.rcParams.update({
    "font.family": "sans-serif",
    "font.size": 10,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "figure.dpi": 120,
    "axes.labelsize": 10,
    "axes.titlesize": 11,
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
})

# 11 focal epitopes (TCREmp paper short codes) plus "other"
FOCAL_EPITOPES = ["CIN", "ELA", "GIL", "GLC", "LLW", "NLV", "PKY", "SPR", "TFE", "TTD", "YLQ"]
# 11 distinct colours (ColorBrewer Set1 and Dark2, plus tab10 cyan)
_COLORS_11 = [
    "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
    "#a65628", "#f781bf", "#17becf", "#e7298a", "#1b9e77", "#d95f02",
]
EPITOPE_PALETTE = dict(zip(FOCAL_EPITOPES, _COLORS_11))
EPITOPE_PALETTE["other"] = "#cccccc"

# Report versions once, after all imports.
print(f"Python {sys.version.split()[0]}")
for _pkg in ["mirpy-lib", "numpy", "pandas", "matplotlib", "scipy", "polars",
             "scikit-learn", "umap-learn", "kneed"]:
    try:
        print(f"  {_pkg}: {_meta.version(_pkg)}")
    except _meta.PackageNotFoundError:
        pass
Python 3.12.12
  mirpy-lib: 1.1.0
  numpy: 1.26.4
  pandas: 3.0.3
  matplotlib: 3.10.9
  scipy: 1.17.1
  polars: 1.40.1
  scikit-learn: 1.8.0
  umap-learn: 0.5.12
  kneed: 0.8.6
[2]:
# Download/load VDJdb slim from AIRR benchmark assets.
benchmark_root = ensure_airr_benchmark(repo_root, allow_patterns=["vdjdb/**"])
vdjdb_path = find_airr_benchmark_vdjdb_slim(benchmark_root)
vdjdb = pd.read_csv(vdjdb_path, sep='\t')

print(f"Loaded: {vdjdb_path}")
print(f"Rows: {len(vdjdb):,}  Columns: {len(vdjdb.columns)}")
print("All columns:", vdjdb.columns.tolist())
print("\ngene counts:\n", vdjdb['gene'].value_counts().to_string())
if 'complex.id' in vdjdb.columns:
    # complex.id may be parsed as text; compare as strings, as the filters below do
    print(f"\nPaired records (complex.id != 0): {(vdjdb['complex.id'].astype(str) != '0').sum():,}")
if 'reference.id' in vdjdb.columns:
    tenx = vdjdb['reference.id'].astype(str).str.contains('10[xX]', na=False)
    print(f"10x Genomics records: {tenx.sum():,}")
Loaded: /Users/mikesh/vcs/mirpy/notebooks/assets/large/airr_benchmark/vdjdb/vdjdb-2025-12-29/vdjdb.slim.txt.gz
Rows: 145,408  Columns: 18
All columns: ['gene', 'cdr3', 'species', 'antigen.epitope', 'antigen.gene', 'antigen.species', 'complex.id', 'v.segm', 'j.segm', 'mhc.a', 'mhc.b', 'mhc.class', 'reference.id', 'vdjdb.score', 'vdjdb.pgen.score', 'TCR_hash', 'j.start', 'v.end']

gene counts:
 gene
TRB    95315
TRA    50093

Paired records (complex.id != 0): 145,408
10x Genomics records: 34,073
[3]:
# Map 3-letter epitope abbreviations (as used in the TCREmp paper) to full peptide sequences in VDJdb.
# Search by prefix among human, non-10x records and keep the most frequent match, so the mapping is data-driven.
_df_nref = vdjdb[
    ~vdjdb['reference.id'].astype(str).str.contains('10[xX]', na=False) &
    vdjdb['species'].astype(str).str.lower().isin(['homosapiens', 'human'])
].copy()

prefixes = FOCAL_EPITOPES  # same short codes as in the setup cell
FOCAL_MAP = {}  # short code -> full epitope sequence

for pfx in prefixes:
    hits = _df_nref[_df_nref['antigen.epitope'].str.startswith(pfx, na=False)]['antigen.epitope'].value_counts()
    if len(hits):
        FOCAL_MAP[pfx] = hits.index[0]

# Reverse map: full sequence -> short code for display labels
FOCAL_SEQ2CODE = {v: k for k, v in FOCAL_MAP.items()}
FOCAL_SEQS = list(FOCAL_MAP.values())

# EPITOPE_PALETTE is already keyed on these short codes (see the setup cell).
print("Epitope code → full sequence mapping:")
for code, seq in sorted(FOCAL_MAP.items()):
    print(f"  {code:4s} → {seq}")

# Paired records summary (10x removed)
_c_tra = set(_df_nref[_df_nref['gene'] == 'TRA']['complex.id'].astype(str)) - {'0'}
_c_trb = set(_df_nref[_df_nref['gene'] == 'TRB']['complex.id'].astype(str)) - {'0'}
_both = _c_tra & _c_trb
print(f"\nNon-10x paired clones available: {len(_both):,}")
focal_trb = _df_nref[
    (_df_nref['gene'] == 'TRB') & (_df_nref['antigen.epitope'].isin(FOCAL_SEQS))
]
print(f"TRB records in focal epitopes: {len(focal_trb):,}")
Epitope code → full sequence mapping:
  CIN  → CINGVCWTV
  ELA  → ELAGIGILTV
  GIL  → GILGFVFTL
  GLC  → GLCTLVAML
  LLW  → LLWNGPMAV
  NLV  → NLVPMVATV
  PKY  → PKYVKQNTLKLAT
  SPR  → SPRWYFYYL
  TFE  → TFEYVSQPFLMDLE
  TTD  → TTDPSFLGRY
  YLQ  → YLQPRTFLL

Non-10x paired clones available: 13,893
TRB records in focal epitopes: 23,354
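The prefix lookup in cell [3] keeps the most frequent full epitope among human, non-10x records for each three-letter code. The same rule can be sketched in plain Python on toy data. `map_prefixes` is a hypothetical helper, not a mirpy function. Its alphabetical tie-break is an assumption, because pandas' `value_counts` may order ties differently.

```python
from collections import Counter


def map_prefixes(epitopes, prefixes):
    """Map each prefix to the most frequent epitope that starts with it.

    Missing values are skipped, and prefixes with no match are left out,
    like the `if len(hits)` guard in the notebook cell.
    """
    counts = Counter(e for e in epitopes if isinstance(e, str))
    mapping = {}
    for pfx in prefixes:
        hits = [(n, e) for e, n in counts.items() if e.startswith(pfx)]
        if hits:
            # Highest count first; ties broken alphabetically (an assumption).
            mapping[pfx] = min(hits, key=lambda t: (-t[0], t[1]))[1]
    return mapping


# Toy data: "GILGFVFTA" is a made-up variant sharing the GIL prefix.
toy_epitopes = ["GILGFVFTL", "GILGFVFTL", "GILGFVFTA", "NLVPMVATV", None]
toy_map = map_prefixes(toy_epitopes, ["GIL", "NLV", "XYZ"])
print(toy_map)  # {'GIL': 'GILGFVFTL', 'NLV': 'NLVPMVATV'}
```

A prefix is only a safe key when the intended peptide outnumbers every other epitope that shares those three residues. For that reason, the printed mapping is worth checking against the intended peptides.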
[4]:
# Filter VDJdb: paired records (complex.id != '0'), remove 10x Genomics.
# Label epitopes: focal short-code or 'other'.
# Balance by CLONE (complex.id), not by individual chain, to preserve TRA-TRB pairing.

def filter_both_chains(df):
    """Keep only complex.ids with both TRA and TRB present."""
    cids_tra = set(df[df['gene'] == 'TRA']['complex.id'].astype(str))
    cids_trb = set(df[df['gene'] == 'TRB']['complex.id'].astype(str))
    valid = (cids_tra & cids_trb) - {'0'}
    return df[df['complex.id'].astype(str).isin(valid)].copy()

df_raw = vdjdb.copy()

# Step 1: Remove 10x Genomics, keep HomoSapiens, paired only
df_raw = df_raw[~df_raw['reference.id'].astype(str).str.contains('10[xX]', na=False)]
df_raw = df_raw[df_raw['species'].astype(str).str.lower().isin(['homosapiens', 'human'])]
df_raw = filter_both_chains(df_raw)
print(f"Paired non-10x: {len(df_raw):,} records  ({df_raw['complex.id'].nunique():,} clones)")

# Step 2: CDR3 length filter plus non-missing V, J, and epitope
# (astype(str) turns missing values into the string 'nan', so exclude it explicitly)
for col in ['cdr3', 'v.segm', 'j.segm', 'antigen.epitope']:
    df_raw[col] = df_raw[col].astype(str).str.strip()
df_raw = df_raw[
    df_raw['cdr3'].str.len().between(5, 29) &
    (~df_raw['v.segm'].isin(['nan', ''])) &
    (~df_raw['j.segm'].isin(['nan', ''])) &
    (~df_raw['antigen.epitope'].isin(['nan', '']))
]
df_raw = df_raw.drop_duplicates(subset=['gene', 'cdr3', 'v.segm', 'j.segm', 'antigen.epitope'])
df_raw = filter_both_chains(df_raw)

# Step 3: Balance by complex.id (clone-level), not record-level, to keep TRA-TRB pairs intact.
# Use epitope annotation from TRB to label each clone, then cap clones per epitope.
MAX_FOCAL_CLONES = 500
MAX_OTHER_CLONES = 300

trb_df = df_raw[df_raw['gene'] == 'TRB'][['complex.id', 'antigen.epitope']].copy()
trb_df['epitope_cat'] = trb_df['antigen.epitope'].map(FOCAL_SEQ2CODE).fillna('other')

rng = np.random.RandomState(SEED)
sampled_cids = []
for ep_full, grp in trb_df.groupby('antigen.epitope'):
    cap = MAX_FOCAL_CLONES if ep_full in FOCAL_SEQS else MAX_OTHER_CLONES
    cids = grp['complex.id'].unique()
    chosen = rng.choice(cids, size=min(len(cids), cap), replace=False)
    sampled_cids.extend(chosen.tolist())

df_bal = df_raw[df_raw['complex.id'].isin(sampled_cids)].copy()

# Step 4: Map epitope_cat via TRB annotation → propagate to TRA via complex.id
cid_cat = trb_df.set_index('complex.id')['epitope_cat'].to_dict()
df_bal['epitope_cat'] = df_bal['complex.id'].map(cid_cat).fillna('other')

# Step 5: Add CDR3 length column
df_bal['cdr3_len'] = df_bal['cdr3'].str.len()

df_tra = df_bal[df_bal['gene'] == 'TRA'].reset_index(drop=True)
df_trb = df_bal[df_bal['gene'] == 'TRB'].reset_index(drop=True)

print(f"\nFinal  TRA: {len(df_tra):,}   TRB: {len(df_trb):,}")
print("\nTRB epitope_cat value counts:")
print(df_trb['epitope_cat'].value_counts().to_string())
Paired non-10x: 27,786 records  (13,893 clones)

Final  TRA: 8,429   TRB: 8,429

TRB epitope_cat value counts:
epitope_cat
other    5822
GIL       500
NLV       500
TTD       353
SPR       346
YLQ       294
CIN       200
LLW       164
GLC        90
TFE        64
ELA        50
PKY        46
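Cell [4] balances at the clone level, in three steps:

- A clone counts as paired only if its `complex.id` has both a TRA and a TRB record.
- Sampling draws complex IDs per TRB epitope.
- Both chains of each chosen clone are kept.

A stdlib sketch of that logic follows. It assumes records are `(gene, complex_id, epitope)` tuples with string IDs, where '0' means unpaired. The function names are hypothetical. For the same seed, `random.Random.sample` draws a different subset than NumPy's `RandomState.choice`.

```python
import random
from collections import defaultdict


def clones_with_both_chains(records):
    """Complex IDs with at least one TRA and one TRB record; '0' marks unpaired records."""
    tra = {cid for gene, cid, _ in records if gene == "TRA"}
    trb = {cid for gene, cid, _ in records if gene == "TRB"}
    return (tra & trb) - {"0"}


def balanced_clone_sample(records, focal, cap_focal, cap_other, seed=42):
    """Keep up to cap_focal clones per focal epitope and cap_other per other epitope.

    Clones are grouped by their TRB epitope label; every record (TRA and TRB)
    of a sampled clone is returned, so pairs stay intact.
    """
    paired = clones_with_both_chains(records)
    by_epitope = defaultdict(set)
    for gene, cid, epitope in records:
        if gene == "TRB" and cid in paired:
            by_epitope[epitope].add(cid)

    rng = random.Random(seed)
    chosen = set()
    for epitope in sorted(by_epitope):  # sorted for reproducible draws
        cids = sorted(by_epitope[epitope])
        cap = cap_focal if epitope in focal else cap_other
        chosen.update(rng.sample(cids, min(len(cids), cap)))
    return [rec for rec in records if rec[1] in chosen]


toy_records = [
    ("TRA", "1", "GILGFVFTL"), ("TRB", "1", "GILGFVFTL"),
    ("TRA", "2", "GILGFVFTL"), ("TRB", "2", "GILGFVFTL"),
    ("TRA", "3", "GILGFVFTL"), ("TRB", "3", "GILGFVFTL"),
    ("TRB", "4", "GILGFVFTL"),                             # TRB only: not paired
    ("TRA", "5", "OTHERPEP"), ("TRB", "5", "OTHERPEP"),    # toy non-focal epitope
    ("TRA", "0", "GILGFVFTL"), ("TRB", "0", "GILGFVFTL"),  # complex.id 0: unpaired
]
sample = balanced_clone_sample(toy_records, focal={"GILGFVFTL"}, cap_focal=2, cap_other=1)
print(sorted({cid for _, cid, _ in sample}))
```

Sampling complex IDs rather than rows is what keeps the TRA and TRB tables the same length (8,429 each above).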
[5]:
# Build separate TRA and TRB Clonotype lists for independent embedding.
def build_clonotypes(df, locus):
    return [
        Clonotype(
            sequence_id=str(i),
            locus=locus,
            v_gene=row['v.segm'],
            j_gene=row['j.segm'],
            junction_aa=row['cdr3'],
            duplicate_count=1,
            _validate=False,
        )
        for i, row in df.iterrows()
    ]

clono_tra = build_clonotypes(df_tra, 'TRA')
clono_trb = build_clonotypes(df_trb, 'TRB')

labels_tra = df_tra['epitope_cat'].to_numpy()
labels_trb = df_trb['epitope_cat'].to_numpy()

print(f"TRA clonotypes: {len(clono_tra):,}   TRB clonotypes: {len(clono_trb):,}")
print(f"TRA unique labels: {sorted(set(labels_tra))}")
print(f"TRB unique labels: {sorted(set(labels_trb))}")
TRA clonotypes: 8,429   TRB clonotypes: 8,429
TRA unique labels: ['CIN', 'ELA', 'GIL', 'GLC', 'LLW', 'NLV', 'PKY', 'SPR', 'TFE', 'TTD', 'YLQ', 'other']
TRB unique labels: ['CIN', 'ELA', 'GIL', 'GLC', 'LLW', 'NLV', 'PKY', 'SPR', 'TFE', 'TTD', 'YLQ', 'other']
[6]:
# Embed TRA and TRB chains separately with TCREmp fixed-gap backend (1000 prototypes each).
N_PROTO = 1000

model_tra = TCREmp.from_defaults(species='human', locus='TRA', n_prototypes=N_PROTO, junction_method='fixed_gap')
model_trb = TCREmp.from_defaults(species='human', locus='TRB', n_prototypes=N_PROTO, junction_method='fixed_gap')

t0 = time.perf_counter()
X_tra_raw = model_tra.embed(clono_tra, n_jobs=None)
t_tra = time.perf_counter() - t0

t0 = time.perf_counter()
X_trb_raw = model_trb.embed(clono_trb, n_jobs=None)
t_trb = time.perf_counter() - t0

print(f"TRA: {X_tra_raw.shape}  {t_tra:.1f}s  "
      f"({len(clono_tra) * N_PROTO / max(t_tra, 1e-9) / 1e6:.1f}M pairs/s)")
print(f"TRB: {X_trb_raw.shape}  {t_trb:.1f}s  "
      f"({len(clono_trb) * N_PROTO / max(t_trb, 1e-9) / 1e6:.1f}M pairs/s)")
TRA: (8429, 3000)  0.4s  (19.9M pairs/s)
TRB: (8429, 3000)  0.5s  (18.2M pairs/s)
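Both output shapes are `(8429, 3000)`, which is 1000 prototypes × 3 features per prototype. The reported rate is clonotypes × prototypes divided by the unrounded wall time, so it does not exactly equal 8429 × 1000 divided by the displayed seconds.

The toy sketch below only illustrates that layout. Its scores and the per-prototype `[V, J, CDR3]` ordering are assumptions, not mirpy's TCREmp scoring or its `fixed_gap` junction method:

- V and J genes score by identity.
- CDR3s score by position-wise identity, without alignment.

```python
def toy_prototype_embedding(clonotypes, prototypes):
    """Embed (v_gene, j_gene, cdr3) tuples as 3 scores per prototype.

    Stand-in scoring (NOT TCREmp's): 1.0/0.0 identity for V and J genes, and the
    fraction of identical positions for CDR3 compared left-aligned, without gaps.
    """
    def cdr3_identity(a, b):
        same = sum(x == y for x, y in zip(a, b))
        return same / max(len(a), len(b))

    rows = []
    for v, j, cdr3 in clonotypes:
        row = []
        for pv, pj, pcdr3 in prototypes:
            row.extend([float(v == pv), float(j == pj), cdr3_identity(cdr3, pcdr3)])
        rows.append(row)
    return rows


toy_prototypes = [
    ("TRBV19", "TRBJ2-7", "CASSIRSSYEQYF"),
    ("TRBV20-1", "TRBJ1-1", "CSARDRTGNTEAFF"),
]
toy_clonotypes = [("TRBV19", "TRBJ2-7", "CASSIRSAYEQYF")]
emb = toy_prototype_embedding(toy_clonotypes, toy_prototypes)
print(len(emb[0]))  # 6 == 3 * len(toy_prototypes)
```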
[7]:
# StandardScaler → PCA (90% variance) → row-wise L2 normalisation.
# After L2 normalisation every point lies on the unit hypersphere, so the Euclidean
# distance between two rows equals sqrt(2(1 - cos_sim)) and is bounded in [0, 2].
# On this scale eps ≈ 0.35 captures tight, convergent TCR motifs, matching
# the TCREmp paper convention.

from sklearn.preprocessing import normalize as l2normalize

scaler_tra = StandardScaler()
scaler_trb = StandardScaler()
X_tra_sc = scaler_tra.fit_transform(X_tra_raw)
X_trb_sc = scaler_trb.fit_transform(X_trb_raw)

# Full PCA for cumulative variance
pca_tra_full = PCA(random_state=SEED).fit(X_tra_sc)
pca_trb_full = PCA(random_state=SEED).fit(X_trb_sc)

cum_tra = np.cumsum(pca_tra_full.explained_variance_ratio_)
cum_trb = np.cumsum(pca_trb_full.explained_variance_ratio_)

VARIANCE_THRESHOLD = 0.90
n_tra = int(np.searchsorted(cum_tra, VARIANCE_THRESHOLD)) + 1
n_trb = int(np.searchsorted(cum_trb, VARIANCE_THRESHOLD)) + 1
print(f"TRA: {n_tra} PCs → {cum_tra[n_tra-1]*100:.1f}% variance")
print(f"TRB: {n_trb} PCs → {cum_trb[n_trb-1]*100:.1f}% variance")

# Refit PCA then L2-normalise rows → unit-sphere distances
pca_tra = PCA(n_components=n_tra, random_state=SEED)
pca_trb = PCA(n_components=n_trb, random_state=SEED)

X_tra_pca = l2normalize(pca_tra.fit_transform(X_tra_sc))   # used for clustering
X_trb_pca = l2normalize(pca_trb.fit_transform(X_trb_sc))

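# Sanity check on a 1,000-point subsample: median 4-NN distance (must lie in [0, 2]).
# n_neighbors=4 counts the query point itself, so column -1 is the 3rd-nearest other point.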
_nn4 = NearestNeighbors(n_neighbors=4)
_nn4.fit(X_tra_pca[:1000])
med_tra = float(np.median(_nn4.kneighbors(X_tra_pca[:1000])[0][:, -1]))
_nn4.fit(X_trb_pca[:1000])
med_trb = float(np.median(_nn4.kneighbors(X_trb_pca[:1000])[0][:, -1]))
print(f"Sample median 4-NN: TRA={med_tra:.3f}  TRB={med_trb:.3f}  (max possible: 2.0)")

# Cumulative variance plot
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, cum, n_comp, title, color in zip(
    axes, [cum_tra, cum_trb], [n_tra, n_trb],
    ['TRA', 'TRB'], ['#377eb8', '#e41a1c']
):
    max_show = min(200, len(cum))
    ax.plot(np.arange(1, max_show + 1), cum[:max_show] * 100, lw=1.5, color=color)
    ax.axhline(VARIANCE_THRESHOLD * 100, color='#888', ls='--', lw=1, label=f'{VARIANCE_THRESHOLD*100:.0f}% threshold')
    ax.axvline(n_comp, color='black', ls=':', lw=1.5, label=f'{n_comp} PCs selected')
    ax.set_xlim(0, max_show)
    ax.set_ylim(0, 102)
    ax.set_xlabel('Number of principal components')
    ax.set_ylabel('Cumulative explained variance (%)')
    ax.set_title(f'TCREmp {title}: PCA variance')
    ax.legend(fontsize=9)

plt.tight_layout()
plt.show()
TRA: 45 PCs → 90.0% variance
TRB: 28 PCs → 90.4% variance
Sample median 4-NN: TRA=0.401  TRB=0.392  (max possible: 2.0)
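[Figure: cumulative explained variance (%) vs number of principal components, panels "TCREmp TRA: PCA variance" and "TCREmp TRB: PCA variance", with the 90% threshold and the selected PC count marked]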
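The component count is the smallest k whose cumulative explained-variance ratio reaches the threshold. `np.searchsorted` with its default `side='left'` returns the first index where the cumulative sum is ≥ 0.90, hence the `+ 1`.

Below is a stdlib equivalent using `bisect_left`, shown on made-up ratios; the helper name is hypothetical. Clamping to the number of components is an added guard for thresholds that a floating-point cumulative sum never quite reaches.

```python
from bisect import bisect_left
from itertools import accumulate


def n_components_for_variance(explained_ratio, threshold=0.90):
    """Smallest k such that sum(explained_ratio[:k]) >= threshold (clamped to the total)."""
    cum = list(accumulate(explained_ratio))
    k = bisect_left(cum, threshold) + 1  # first index with cum >= threshold, as a count
    return min(k, len(cum))


# Made-up ratios chosen to be exact in binary floating point.
ratios = [0.5, 0.25, 0.125, 0.0625, 0.0625]
print(n_components_for_variance(ratios, 0.90))  # 4 (cumulative: 0.5, 0.75, 0.875, 0.9375, 1.0)
```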
[8]:
# Figure: first two (L2-normalised) PC coordinates coloured by V gene, J gene, and CDR3 length,
# TRA and TRB side by side. Shows how the TCREmp embedding captures gene-usage and CDR3-length structure.

def scatter_categorical(ax, xy, values, title, top_n=8, s=6, alpha=0.55):
    """Scatter plot colored by top-N categories; rest as 'other' (grey)."""
    counts = values.value_counts()
    top_cats = counts.index[:top_n].tolist()
    tab = plt.get_cmap('tab10')
    pal = {cat: tab(i) for i, cat in enumerate(top_cats)}
    pal['other'] = '#cccccc'
    cats = values.where(values.isin(top_cats), other='other')
    # plot 'other' first so focal categories are on top
    for cat in ['other'] + top_cats:
        mask = cats == cat
        if mask.sum() == 0:
            continue
        ax.scatter(xy[mask, 0], xy[mask, 1],
                   c=[pal[cat]], s=s, alpha=alpha,
                   label=cat if cat != 'other' else None, rasterized=True)
    ax.legend(loc='upper right', fontsize=7, markerscale=2,
              frameon=False, ncol=1, handlelength=1)
    ax.set_title(title, pad=4)
    ax.set_xlabel('PC 1', labelpad=2)
    ax.set_ylabel('PC 2', labelpad=2)
    ax.tick_params(labelsize=7)

def scatter_continuous(ax, xy, values, title, cmap='viridis', s=6, alpha=0.55):
    """Scatter plot colored by a continuous variable."""
    sc = ax.scatter(xy[:, 0], xy[:, 1], c=values, cmap=cmap,
                    s=s, alpha=alpha, rasterized=True)
    cbar = plt.colorbar(sc, ax=ax, pad=0.02, shrink=0.85)
    cbar.ax.tick_params(labelsize=7)
    ax.set_title(title, pad=4)
    ax.set_xlabel('PC 1', labelpad=2)
    ax.set_ylabel('PC 2', labelpad=2)
    ax.tick_params(labelsize=7)

fig, axes = plt.subplots(3, 2, figsize=(12, 14))

pairs_info = [
    # (PCA coords, df, feature, plot_type, top_n, title_suffix)
    (X_tra_pca, df_tra, 'v.segm', 'cat', 8, 'TRA — V gene'),
    (X_trb_pca, df_trb, 'v.segm', 'cat', 8, 'TRB — V gene'),
    (X_tra_pca, df_tra, 'j.segm', 'cat', 8, 'TRA — J gene'),
    (X_trb_pca, df_trb, 'j.segm', 'cat', 8, 'TRB — J gene'),
    (X_tra_pca, df_tra, 'cdr3_len', 'cont', None, 'TRA — CDR3 length'),
    (X_trb_pca, df_trb, 'cdr3_len', 'cont', None, 'TRB — CDR3 length'),
]

for ax, (xy, df, feat, ptype, top_n, title) in zip(axes.flatten(), pairs_info):
    if ptype == 'cat':
        scatter_categorical(ax, xy, df[feat], title, top_n=top_n)
    else:
        scatter_continuous(ax, xy, df[feat].values, title, cmap='plasma')

plt.suptitle('TCREmp PCA — structural features', fontsize=13, y=1.01)
plt.tight_layout()
plt.show()
[Figure: TCREmp PCA structural features; PC 1 vs PC 2 for TRA and TRB coloured by V gene, J gene, and CDR3 length]
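`scatter_categorical` keeps the `top_n` most frequent V or J genes and greys out the rest. Below, the relabelling step on its own as a stdlib sketch. The helper is hypothetical, and `Counter.most_common` breaks count ties by first occurrence, which may not match the tie order of pandas' `value_counts`.

```python
from collections import Counter


def collapse_to_top_n(values, top_n=8, other="other"):
    """Keep the top_n most frequent categories; relabel everything else as `other`."""
    top = {cat for cat, _ in Counter(values).most_common(top_n)}
    return [v if v in top else other for v in values]


genes = ["TRBV19"] * 3 + ["TRBV20-1"] * 2 + ["TRBV7-9"]
print(collapse_to_top_n(genes, top_n=2))
# ['TRBV19', 'TRBV19', 'TRBV19', 'TRBV20-1', 'TRBV20-1', 'other']
```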
[9]:
# Figure: PCA PC1 vs PC2 colored by epitope (top-11 focal + 'other' in grey).
# TRA and TRB side-by-side. Demonstrates antigen-specificity signal in embedding.

def scatter_epitopes(ax, xy, labels, title, s=8, alpha=0.6):
    """Scatter with EPITOPE_PALETTE: focal epitopes on top of 'other'."""
    # Plot 'other' first (background)
    mask_other = labels == 'other'
    ax.scatter(xy[mask_other, 0], xy[mask_other, 1],
               c=EPITOPE_PALETTE['other'], s=s, alpha=0.25, rasterized=True, label=None)
    # Plot focal epitopes
    for ep in FOCAL_EPITOPES:
        mask = labels == ep
        if mask.sum() == 0:
            continue
        ax.scatter(xy[mask, 0], xy[mask, 1],
                   c=EPITOPE_PALETTE[ep], s=s, alpha=alpha, rasterized=True, label=ep)
    ax.set_title(title, pad=4)
    ax.set_xlabel('PC 1', labelpad=2)
    ax.set_ylabel('PC 2', labelpad=2)
    ax.tick_params(labelsize=7)

fig, axes = plt.subplots(1, 2, figsize=(15, 6))

scatter_epitopes(axes[0], X_tra_pca, labels_tra, 'TRA — PCA by epitope')
scatter_epitopes(axes[1], X_trb_pca, labels_trb, 'TRB — PCA by epitope')

# Shared legend (focal epitopes plus 'other')
handles = [mpatches.Patch(color=EPITOPE_PALETTE[ep], label=ep) for ep in FOCAL_EPITOPES]
handles += [mpatches.Patch(color=EPITOPE_PALETTE['other'], alpha=0.4, label='other')]
axes[1].legend(handles=handles, loc='center left', bbox_to_anchor=(1.02, 0.5),
               frameon=False, fontsize=8, handlelength=1.2)

plt.suptitle('TCREmp PCA — antigen specificity', fontsize=13, y=1.01)
plt.tight_layout()
plt.show()
[Figure: TCREmp PCA antigen specificity; PC 1 vs PC 2 for TRA and TRB coloured by the 11 focal epitopes, with 'other' in grey]
[10]:
# Kneedle eps selection + DBSCAN clustering, run separately for TRA and TRB.
# DBSCAN runs on the L2-normalised PCA coordinates, not on UMAP; UMAP is used only for visualisation.

K_NEIGHBORS = 4   # k-th neighbour used for eps estimation (following TCREmp default)
MIN_SAMPLES  = 3  # DBSCAN min_samples

def select_eps_kneedle(X_pca, k=4):
    """Sort k-NN distances (k counts the point itself) and pick eps at the kneedle knee."""
    nn = NearestNeighbors(n_neighbors=k, metric='euclidean', algorithm='auto')
    nn.fit(X_pca)
    dists, _ = nn.kneighbors(X_pca)
    kth = np.sort(dists[:, -1])

    kl = KneeLocator(
        np.arange(len(kth)), kth,
        curve='convex', direction='increasing',
        interp_method='polynomial',
    )
    eps = float(kth[kl.knee]) if kl.knee is not None else float(np.percentile(kth, 10))
    return kth, kl, eps

# --- Kneedle for TRA ---
kth_tra, kl_tra, eps_tra = select_eps_kneedle(X_tra_pca, k=K_NEIGHBORS)
# --- Kneedle for TRB ---
kth_trb, kl_trb, eps_trb = select_eps_kneedle(X_trb_pca, k=K_NEIGHBORS)

print(f"Kneedle eps  TRA: {eps_tra:.4f}   TRB: {eps_trb:.4f}")

# ─── Kneedle plot ────────────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, kth, kl, eps, title, color in zip(
    axes,
    [kth_tra, kth_trb], [kl_tra, kl_trb], [eps_tra, eps_trb],
    ['TRA', 'TRB'], ['#377eb8', '#e41a1c']
):
    ax.plot(kth, lw=1.2, color=color, label=f'{K_NEIGHBORS}-NN distance')
    if kl.knee is not None:
        ax.axvline(kl.knee, color='black', ls='--', lw=1.2)
        ax.axhline(eps, color='black', ls=':', lw=1.2,
                   label=f'knee → eps = {eps:.3f}')
    ax.set_xlabel('Points sorted by distance', labelpad=2)
    ax.set_ylabel(f'{K_NEIGHBORS}-NN distance', labelpad=2)
    ax.set_title(f'{title}: kneedle eps selection')
    ax.legend(fontsize=9)

plt.tight_layout()
plt.show()

# ─── DBSCAN ──────────────────────────────────────────────────────────────────
db_tra = DBSCAN(eps=eps_tra, min_samples=MIN_SAMPLES, metric='euclidean', n_jobs=-1)
db_trb = DBSCAN(eps=eps_trb, min_samples=MIN_SAMPLES, metric='euclidean', n_jobs=-1)

cluster_tra = db_tra.fit_predict(X_tra_pca)
cluster_trb = db_trb.fit_predict(X_trb_pca)

n_clust_tra = len(set(cluster_tra)) - (1 if -1 in cluster_tra else 0)
n_clust_trb = len(set(cluster_trb)) - (1 if -1 in cluster_trb else 0)

print(f"\nTRA: {n_clust_tra} clusters  noise={(cluster_tra == -1).mean():.2%}")
print(f"TRB: {n_clust_trb} clusters  noise={(cluster_trb == -1).mean():.2%}")
Kneedle eps  TRA: 0.4732   TRB: 0.3843
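[Figure: sorted 4-NN distance vs points sorted by distance, panels "TRA: kneedle eps selection" and "TRB: kneedle eps selection", with the knee and the selected eps marked]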

TRA: 298 clusters  noise=7.01%
TRB: 124 clusters  noise=4.73%
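`KneeLocator(curve='convex', direction='increasing')` looks for the point where the sorted k-NN distance curve bends upward. The core Kneedle idea can be sketched without the library:

1. Rescale both axes to [0, 1].
2. Take the index that lies furthest below the diagonal, that is, the index that maximises x_n − y_n.

This is a simplification, and the helper name is hypothetical. kneed also applies a sensitivity threshold and, with `interp_method='polynomial'`, fits a polynomial first, so its knee can differ.

```python
def simple_kneedle(y):
    """Knee index of an ascending, convex sequence (simplified Kneedle).

    Normalises x (the index) and y to [0, 1] and returns the index that maximises
    x_n - y_n. Returns None for sequences too short or flat to have a knee.
    """
    n = len(y)
    if n < 3:
        return None
    lo, hi = min(y), max(y)
    if hi == lo:
        return None
    diff = [i / (n - 1) - (v - lo) / (hi - lo) for i, v in enumerate(y)]
    return max(range(n), key=diff.__getitem__)


# Toy sorted 4-NN distances: flat at first, then a sharp rise.
kth = [0.10, 0.11, 0.12, 0.13, 0.14, 0.50, 1.00]
knee = simple_kneedle(kth)
print(knee, kth[knee])  # 4 0.14 -> eps = 0.14
```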
[11]:
# Compute UMAP on L2-normalised PCA coordinates (same space used for DBSCAN).
# TRA and TRB are computed separately to preserve chain identity.
import warnings

UMAP_PARAMS = dict(n_components=2, n_neighbors=30, min_dist=0.1,
                   metric='euclidean', random_state=SEED)

t0 = time.perf_counter()
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    umap_tra = umap.UMAP(**UMAP_PARAMS)
    X_umap_tra = umap_tra.fit_transform(X_tra_pca)
    umap_trb = umap.UMAP(**UMAP_PARAMS)
    X_umap_trb = umap_trb.fit_transform(X_trb_pca)

print(f"UMAP TRA: {X_umap_tra.shape}  TRB: {X_umap_trb.shape}  ({time.perf_counter()-t0:.1f}s)")

# ─── Helper for discrete cluster coloring ────────────────────────────────────
def scatter_clusters(ax, xy, cluster_ids, title, noise_color='#e0e0e0', s=7, alpha=0.55):
    """Color each cluster; noise (-1) drawn first in light grey."""
    unique_ids = sorted(c for c in np.unique(cluster_ids) if c != -1)
    cmap = plt.get_cmap('tab20')
    mask_noise = cluster_ids == -1
    ax.scatter(xy[mask_noise, 0], xy[mask_noise, 1],
               c=noise_color, s=s * 0.6, alpha=0.3, rasterized=True)
    for i, cid in enumerate(unique_ids):
        mask = cluster_ids == cid
        ax.scatter(xy[mask, 0], xy[mask, 1],
                   c=[cmap(i % 20)], s=s, alpha=alpha, rasterized=True)
    ax.set_title(title, pad=4)
    ax.set_xlabel('UMAP 1', labelpad=2)
    ax.set_ylabel('UMAP 2', labelpad=2)
    ax.tick_params(labelsize=7)

def scatter_epitopes_umap(ax, xy, labels, title):
    """Scatter colored by epitope category on arbitrary 2D embedding."""
    order_other = labels == 'other'
    ax.scatter(xy[order_other, 0], xy[order_other, 1],
               c=EPITOPE_PALETTE['other'], s=6, alpha=0.25, rasterized=True)
    for ep in FOCAL_EPITOPES:
        mask = labels == ep
        ax.scatter(xy[mask, 0], xy[mask, 1],
                   c=EPITOPE_PALETTE[ep], s=8, alpha=0.70, rasterized=True, label=ep)
    ax.set_title(title, pad=4)
    ax.set_xlabel('UMAP 1', labelpad=2)
    ax.set_ylabel('UMAP 2', labelpad=2)
    ax.tick_params(labelsize=7)

# ─── 2×2 figure: clusters (top) and epitopes (bottom) ───────────────────────
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

n_tra_cl = len(set(cluster_tra) - {-1})
n_trb_cl = len(set(cluster_trb) - {-1})
scatter_clusters(axes[0, 0], X_umap_tra, cluster_tra, f'TRA — UMAP by cluster  (n={n_tra_cl})')
scatter_clusters(axes[0, 1], X_umap_trb, cluster_trb, f'TRB — UMAP by cluster  (n={n_trb_cl})')

scatter_epitopes_umap(axes[1, 0], X_umap_tra, labels_tra, 'TRA — UMAP by epitope')
scatter_epitopes_umap(axes[1, 1], X_umap_trb, labels_trb, 'TRB — UMAP by epitope')

handles = [mpatches.Patch(color=EPITOPE_PALETTE[ep], label=ep) for ep in FOCAL_EPITOPES]
handles += [mpatches.Patch(color=EPITOPE_PALETTE['other'], alpha=0.4, label='other')]
axes[1, 1].legend(handles=handles, loc='center left', bbox_to_anchor=(1.02, 0.5),
                  frameon=False, fontsize=8, handlelength=1.2)

plt.suptitle('TCREmp UMAP — cluster structure and antigen specificity', fontsize=13, y=1.01)
plt.tight_layout()
plt.show()
UMAP TRA: (8429, 2)  TRB: (8429, 2)  (51.6s)
[Figure: TCREmp UMAP cluster structure and antigen specificity; TRA and TRB coloured by DBSCAN cluster (top) and by epitope (bottom)]
[12]:
# Clustering quality metrics: purity, retention, consistency.
#
# Definitions (cluster-level, excluding noise):
#   retention   = fraction of sequences assigned to any cluster (1 - noise fraction)
#   purity      = mean over clusters of (dominant-label fraction per cluster)
#   consistency = fraction of clusters where dominant-label fraction >= threshold

PURITY_THRESHOLD = 0.70

def clustering_metrics(labels, clusters, threshold=PURITY_THRESHOLD):
    mask = clusters != -1
    retention = float(mask.sum()) / len(clusters)
    cluster_ids = np.unique(clusters[mask])
    per_cluster_purity = []
    for cid in cluster_ids:
        cl_labels = labels[clusters == cid]
        dominant = pd.Series(cl_labels).value_counts().iloc[0] / len(cl_labels)
        per_cluster_purity.append(dominant)
    purity = float(np.mean(per_cluster_purity))
    consistency = float(np.mean([p >= threshold for p in per_cluster_purity]))
    return dict(
        n_clusters=len(cluster_ids),
        retention=retention,
        purity=purity,
        consistency=consistency,
    )

labels_tra_arr = np.array(labels_tra)
labels_trb_arr = np.array(labels_trb)

m_tra = clustering_metrics(labels_tra_arr, cluster_tra)
m_trb = clustering_metrics(labels_trb_arr, cluster_trb)

summary_df = pd.DataFrame([m_tra, m_trb], index=['TRA', 'TRB'])
summary_df['eps'] = [eps_tra, eps_trb]
summary_df['min_samples'] = MIN_SAMPLES
summary_df = summary_df[['eps', 'min_samples', 'n_clusters', 'retention', 'purity', 'consistency']]

print(f"Purity threshold for consistency: {PURITY_THRESHOLD:.0%}\n")
# Keep columns numeric so the gradient is computed on numbers; format only for display.
display(
    summary_df.style
    .background_gradient(cmap='YlOrRd', axis=0)
    .format({'eps': '{:.4f}', 'retention': '{:.3f}', 'purity': '{:.3f}', 'consistency': '{:.3f}'})
)

# Per-epitope purity on TRB (richer labels)
focal_mask_trb = labels_trb_arr != 'other'
if focal_mask_trb.sum() > 0:
    rows = []
    for ep in FOCAL_EPITOPES:
        ep_mask = labels_trb_arr == ep
        ep_clusters = cluster_trb[ep_mask]
        assigned = ep_clusters[ep_clusters != -1]
        if len(assigned) == 0:
            continue
        # mean_purity: unweighted mean dominant-label fraction over the distinct clusters
        # containing this epitope (the dominant label need not be this epitope)
        ep_purity_vals = []
        for cid in np.unique(assigned):
            cl_labels = labels_trb_arr[cluster_trb == cid]
            dominant_frac = pd.Series(cl_labels).value_counts().iloc[0] / len(cl_labels)
            ep_purity_vals.append(dominant_frac)
        rows.append({
            'epitope': ep,
            'n_seqs': int(ep_mask.sum()),
            'n_assigned': int((ep_clusters != -1).sum()),
            'retention': f"{(ep_clusters != -1).mean():.3f}",
            'mean_purity': f"{np.mean(ep_purity_vals):.3f}",
        })
    print("\nTRB per-focal-epitope breakdown:")
    print(pd.DataFrame(rows).to_string(index=False))
Purity threshold for consistency: 70%

  eps min_samples n_clusters retention purity consistency
TRA 0.4732 3 298 0.930 0.691 0.456
TRB 0.3843 3 124 0.953 0.713 0.540

TRB per-focal-epitope breakdown:
epitope  n_seqs  n_assigned retention mean_purity
    CIN     200         190     0.950       0.668
    ELA      50          48     0.960       0.689
    GIL     500         471     0.942       0.657
    GLC      90          89     0.989       0.663
    LLW     164         159     0.970       0.634
    NLV     500         476     0.952       0.684
    PKY      46          43     0.935       0.669
    SPR     346         327     0.945       0.667
    TFE      64          63     0.984       0.656
    TTD     353         336     0.952       0.666
    YLQ     294         283     0.963       0.682
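The three summary metrics reduce to a few lines of plain Python. This sketch mirrors `clustering_metrics` on toy labels. The `_py` suffix is hypothetical and only avoids shadowing the notebook function. Because purity is an unweighted mean over clusters, a cluster of a few sequences counts as much as a 50-member one.

```python
from collections import Counter


def clustering_metrics_py(labels, clusters, threshold=0.70):
    """Retention, mean per-cluster purity, and consistency; cluster id -1 is noise."""
    members = {}
    for lab, cid in zip(labels, clusters):
        if cid != -1:
            members.setdefault(cid, []).append(lab)
    n_assigned = sum(len(m) for m in members.values())
    # Dominant-label fraction of each cluster
    purities = [Counter(m).most_common(1)[0][1] / len(m) for m in members.values()]
    if not purities:
        return {"n_clusters": 0, "retention": 0.0,
                "purity": float("nan"), "consistency": float("nan")}
    return {
        "n_clusters": len(members),
        "retention": n_assigned / len(clusters),
        "purity": sum(purities) / len(purities),
        "consistency": sum(p >= threshold for p in purities) / len(purities),
    }


toy_labels = ["GIL", "GIL", "GIL", "NLV", "NLV", "YLQ", "other"]
toy_clusters = [0, 0, 0, 1, 1, 1, -1]
print(clustering_metrics_py(toy_labels, toy_clusters))
```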

Summary

Dataset

  • Source: VDJdb slim (2025-12-29), Homo sapiens

  • Filters: paired records only (complex.id ≠ 0, TRA and TRB both present), 10x Genomics excluded, CDR3 length 5–29 with non-empty V, J, and epitope fields

  • Clone-level balanced sampling: up to 500 clones per focal epitope and up to 300 per non-focal epitope (TRB-side label)

  • Final: 8 429 paired TRA/TRB clonotypes, 12 epitope categories (11 focal + other)

  • Focal epitopes: CIN · ELA · GIL · GLC · LLW · NLV · PKY · SPR · TFE · TTD · YLQ

Embedding & Dimensionality Reduction

  • TCREmp (fixed_gap junction method, 1 000 prototypes) → 3 000-dimensional embedding per chain (3 features per prototype)

  • PCA at 90 % variance threshold: 45 components (TRA) · 28 components (TRB)

  • PCA output L2-normalised per row → unit-hypersphere coordinates (Euclidean distance in this space equals √(2(1 − cos sim)), bounded in [0, 2]; calibrates eps scale to the TCREmp paper convention)
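For unit vectors a and b, ‖a − b‖² = ‖a‖² + ‖b‖² − 2a·b = 2(1 − cos θ). This gives the √(2(1 − cos sim)) relation and the [0, 2] range: 0 for identical directions and 2 for opposite ones. Inverting it gives cos sim = 1 − eps²/2. The Kneedle values therefore correspond to cosine similarities of about 0.888 (TRA, eps = 0.473) and 0.926 (TRB, eps = 0.384). A quick stdlib check follows; the helper names are illustrative.

```python
import math


def l2_normalise(v):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_from_eps(eps):
    """Cosine similarity corresponding to Euclidean distance eps between unit vectors."""
    return 1 - eps ** 2 / 2


a = l2_normalise([3.0, 4.0])   # [0.6, 0.8]
b = l2_normalise([4.0, 3.0])   # [0.8, 0.6]
cos_ab = sum(x * y for x, y in zip(a, b))
dist_ab = math.dist(a, b)
print(dist_ab, math.sqrt(2 * (1 - cos_ab)))  # both ~0.2828
print(round(cosine_from_eps(0.384), 3), round(cosine_from_eps(0.473), 3))  # 0.926 0.888
```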

Clustering

  • DBSCAN on L2-normalised PCA coordinates (min_samples = 3)

  • eps selected automatically via Kneedle on sorted 4-NN distances
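DBSCAN follows three rules:

- A point is a core point when at least `min_samples` points, itself included, lie within `eps`.
- Clusters grow from core points through their neighbourhoods, and non-core points reachable from a cluster join it as border points.
- Everything else is noise (−1).

Below is a minimal O(n²) pure-Python sketch of those rules. It is a hypothetical illustration, not scikit-learn's implementation, and its cluster numbering may differ from scikit-learn's.

```python
import math


def dbscan(points, eps, min_samples):
    """Minimal O(n^2) DBSCAN. Returns one label per point; -1 marks noise."""
    n = len(points)
    # Neighbourhoods include the point itself (distance 0 <= eps).
    neighbours = [
        [j for j in range(n) if math.dist(points[i], points[j]) <= eps]
        for i in range(n)
    ]
    labels = [None] * n  # None = not yet visited
    cluster = -1
    for i in range(n):
        if labels[i] is not None:
            continue
        if len(neighbours[i]) < min_samples:
            labels[i] = -1  # provisional noise; may later become a border point
            continue
        cluster += 1
        labels[i] = cluster
        queue = list(neighbours[i])
        while queue:
            j = queue.pop()
            if labels[j] == -1:
                labels[j] = cluster  # non-core point reached from a core point: border
            if labels[j] is not None:
                continue
            labels[j] = cluster
            if len(neighbours[j]) >= min_samples:
                queue.extend(neighbours[j])  # only core points expand the cluster
    return labels


pts = [(0, 0), (0, 0.1), (0.1, 0), (5, 5), (5, 5.1), (5.1, 5), (10, 10)]
print(dbscan(pts, eps=0.2, min_samples=3))  # [0, 0, 0, 1, 1, 1, -1]
```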

Chain   eps (Kneedle)   Clusters   Retention   Purity   Consistency ≥ 70 %
TRA     0.473           298        93.0 %      69.1 %   45.6 %
TRB     0.384           124        95.3 %      71.3 %   54.0 %

  • Retention: fraction of sequences assigned to a cluster (not noise)

  • Purity: mean dominant-label fraction per cluster

  • Consistency: fraction of clusters with purity ≥ 70 %

UMAP Visualisation

  • Fitted on L2-normalised PCA coordinates (the same space used for DBSCAN), not on the raw embeddings

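  • DBSCAN clusters appear as discrete islands in both chains, consistent with convergent TCR motifs

  • TRB clusters align somewhat better with epitope identity than TRA clusters (purity 71.3 % vs 69.1 %, consistency 54.0 % vs 45.6 %), in line with CDR3β often making the dominant contacts with the presented peptide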