Paired TCREmp on VDJdb Full

This notebook analyzes paired TRA/TRB embeddings of human records from vdjdb_full.txt.gz in two modes:

  1. Strict: only records that already contain both the TRA and the TRB chain.

  2. Imputed: the same human VDJdb records, with single-chain records completed by missing-chain imputation.

The paired embedding is a direct concatenation of chain-specific TCREmp embeddings.

[1]:
# Configure imports, plotting, randomness, and print environment versions.
import sys
import time
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
import sklearn
import umap
from IPython.display import display

try:
    import kneed
except ImportError:
    kneed = None

from mir.common.parser import VDJdbFullPairedParser
from mir.common.single_cell import build_tenx_sample_from_cell_clonotypes
from mir.common.single_cell_repair import impute_missing_chains
from mir.embedding.tcremp import PairedTCREmp
from mir.utils.embedding_diagnostics import analyze_embedding_dbscan
from mir.utils.notebook_assets import ensure_airr_benchmark, find_airr_benchmark_vdjdb_full, find_repo_root

# Single global seed; reused below for polars row sampling (seeded_sample),
# the DBSCAN diagnostics (analyze_embedding_dbscan) and UMAP random_state.
SEED = 42
np.random.seed(SEED)

# Curation note: SLLMWITQV is excluded due to technique-specific enrichment
# toward a single degenerate TCR pattern rather than broad repertoire structure.
EXCLUDED_EPITOPES = {'SLLMWITQV'}

# Plot styling shared by every figure in the notebook.
sns.set_theme(style='whitegrid', context='talk')
plt.rcParams['figure.dpi'] = 140
plt.rcParams['savefig.dpi'] = 300

# Record the environment so the outputs below can be reproduced.
print(f'Python: {sys.version.split()[0]}')
print(f'numpy: {np.__version__}')
print(f'polars: {pl.__version__}')
print(f'pandas: {pd.__version__}')
print(f'scikit-learn: {sklearn.__version__}')
print(f'umap-learn: {umap.__version__}')
# kneed is optional: the import fallback sets it to None, hence getattr with a default.
print(f'kneed: {getattr(kneed, "__version__", "not-installed")}')
print(f'Excluded epitopes in this notebook: {sorted(EXCLUDED_EPITOPES)}')
/Users/mikesh/vcs/mirpy/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Python: 3.12.12
numpy: 1.26.4
polars: 1.40.1
pandas: 3.0.3
scikit-learn: 1.8.0
umap-learn: 0.5.12
kneed: 0.8.6
Excluded epitopes in this notebook: ['SLLMWITQV']
[2]:
# Load VDJdb full rows, build strict/imputed paired datasets, and show metadata access patterns.
repo_root = find_repo_root(Path.cwd())
# Benchmark assets restricted to the 'vdjdb/**' pattern; ensure_airr_benchmark presumably
# fetches them when they are missing locally — TODO confirm against its implementation.
dataset_root = ensure_airr_benchmark(repo_root=repo_root, allow_patterns=['vdjdb/**'])
vdjdb_full = find_airr_benchmark_vdjdb_full(dataset_root)
parser = VDJdbFullPairedParser()

# Strict mode: include_incomplete=False keeps only records that already carry both chains.
# The timed span covers parsing plus sample construction.
t0 = time.perf_counter()
strict_cell_df, strict_metadata = parser.parse_cell_clonotypes_file(
    vdjdb_full,
    sample_id='vdjdb_full_human_strict',
    species='HomoSapiens',
    include_incomplete=False,
)
strict_sample = build_tenx_sample_from_cell_clonotypes(
    strict_cell_df,
    sample_id='vdjdb_full_human_strict',
    barcode_metadata=strict_metadata,
)
strict_runtime = time.perf_counter() - t0

# Imputed mode: include_incomplete=True also keeps single-chain records, and
# impute_missing_chains() fills in the absent partner chain before the sample is built.
# The timed span covers parsing, imputation and sample construction.
t0 = time.perf_counter()
impute_input_df, impute_metadata = parser.parse_cell_clonotypes_file(
    vdjdb_full,
    sample_id='vdjdb_full_human_impute',
    species='HomoSapiens',
    include_incomplete=True,
)
imputed_cell_df = impute_missing_chains(impute_input_df)
imputed_sample = build_tenx_sample_from_cell_clonotypes(
    imputed_cell_df,
    sample_id='vdjdb_full_human_impute',
    barcode_metadata=impute_metadata,
)
imputed_runtime = time.perf_counter() - t0

# Paired TRA/TRB clonotypes for each mode; counted here and embedded in a later cell.
strict_pairs = strict_sample.paired_locus_repertoires['TRA_TRB'].paired_clonotypes
imputed_pairs = imputed_sample.paired_locus_repertoires['TRA_TRB'].paired_clonotypes

print(f'VDJdb full file: {vdjdb_full}')
print(f'Strict rows:   {strict_cell_df.height:,} chains -> {len(strict_pairs):,} paired clonotypes ({strict_runtime:.2f}s)')
print(f'Imputed rows:  {imputed_cell_df.height:,} chains -> {len(imputed_pairs):,} paired clonotypes ({imputed_runtime:.2f}s)')
print(f'Added paired clonotypes after imputation: {len(imputed_pairs) - len(strict_pairs):,}')

# Demonstration 1: direct dictionary access by VDJdb record id (synthetic barcode).
for barcode in ['993', '1007']:
    # .get(..., {}) so an unknown record id prints empty fields instead of raising.
    meta = strict_sample.single_cell_repertoire.barcode_metadata.get(barcode, {})
    print(barcode, {k: meta.get(k, '') for k in ['antigen.epitope', 'antigen.gene', 'antigen.species', 'mhc.a', 'mhc.class']})

# Demonstration 2: tabular metadata access for joins/filters.
strict_meta_df = strict_sample.single_cell_repertoire.metadata_to_polars()
print(strict_meta_df.select(['barcode', 'vdjdb_record_id', 'antigen.epitope', 'mhc.a']).head(5))
VDJdb full file: /Users/mikesh/vcs/mirpy/notebooks/assets/large/airr_benchmark/vdjdb/vdjdb-2025-12-29/vdjdb_full.txt.gz
Strict rows:   160,626 chains -> 80,313 paired clonotypes (4.95s)
Imputed rows:  252,274 chains -> 126,137 paired clonotypes (305.57s)
Added paired clonotypes after imputation: 45,824
993 {'antigen.epitope': 'HPNGYKSLSTL', 'antigen.gene': 'NS1', 'antigen.species': 'InfluenzaB', 'mhc.a': 'HLA-B*07:02', 'mhc.class': 'MHCI'}
1007 {'antigen.epitope': 'RPIIRPATL', 'antigen.gene': 'NP', 'antigen.species': 'InfluenzaB', 'mhc.a': 'HLA-B*07:02', 'mhc.class': 'MHCI'}
shape: (5, 4)
┌─────────┬─────────────────┬─────────────────┬─────────────┐
│ barcode ┆ vdjdb_record_id ┆ antigen.epitope ┆ mhc.a       │
│ ---     ┆ ---             ┆ ---             ┆ ---         │
│ str     ┆ str             ┆ str             ┆ str         │
╞═════════╪═════════════════╪═════════════════╪═════════════╡
│ 10000   ┆ 10000           ┆ GILGFVFTL       ┆ HLA-A*02:01 │
│ 100000  ┆ 100000          ┆ NLVPMVATV       ┆ HLA-A*02:01 │
│ 100001  ┆ 100001          ┆ NLVPMVATV       ┆ HLA-A*02:01 │
│ 100002  ┆ 100002          ┆ NLVPMVATV       ┆ HLA-A*02:01 │
│ 100003  ┆ 100003          ┆ NLVPMVATV       ┆ HLA-A*02:01 │
└─────────┴─────────────────┴─────────────────┴─────────────┘
[3]:
# Convert paired repertoires to polars tables and build balanced strict/imputed subsets.
def paired_sample_to_table(sample):
    """Flatten the TRA/TRB paired clonotypes of ``sample`` into a polars table.

    One row per paired clonotype, with:
    - ``pair_id`` and the ``barcode`` derived from it,
    - VDJdb annotations looked up in the sample's barcode metadata
      (missing keys become ''),
    - V/J genes, CDR3 (junction_aa) sequences and CDR3 lengths for each chain.

    Returns
    -------
    pl.DataFrame
        Column order: pair_id, barcode, epitope, antigen_gene, antigen_species,
        mhc_a, reference_id, method_identification, tra_v, tra_j, tra_cdr3,
        trb_v, trb_j, trb_cdr3, tra_len, trb_len, paired_cdr3_len.

        An empty repertoire yields a zero-row frame that still carries this
        schema, so downstream filters on e.g. 'epitope' do not fail on a
        missing column.

    Raises
    ------
    KeyError
        If a pair does not contain both a TRA and a TRB clonotype.
    """
    # Output column -> metadata key in the VDJdb record.
    metadata_columns = {
        'epitope': 'antigen.epitope',
        'antigen_gene': 'antigen.gene',
        'antigen_species': 'antigen.species',
        'mhc_a': 'mhc.a',
        'reference_id': 'reference.id',
        'method_identification': 'method.identification',
    }
    barcode_metadata = sample.single_cell_repertoire.barcode_metadata

    rows = []
    for pair in sample.paired_locus_repertoires['TRA_TRB'].paired_clonotypes:
        # The two clonotypes of a pair are not guaranteed to be in TRA/TRB order; index by locus.
        chains = {pair.clonotype1.locus: pair.clonotype1, pair.clonotype2.locus: pair.clonotype2}
        tra, trb = chains['TRA'], chains['TRB']
        # Barcode = pair_id prefix before the first '_' (assumed pair_id format; it is the
        # key used for barcode_metadata below).
        barcode = pair.pair_id.split('_', 1)[0]
        meta = barcode_metadata.get(barcode, {})

        row = {'pair_id': pair.pair_id, 'barcode': barcode}
        row.update({column: meta.get(key, '') for column, key in metadata_columns.items()})
        row.update(
            {
                'tra_v': tra.v_gene,
                'tra_j': tra.j_gene,
                'tra_cdr3': tra.junction_aa,
                'trb_v': trb.v_gene,
                'trb_j': trb.j_gene,
                'trb_cdr3': trb.junction_aa,
                'tra_len': len(tra.junction_aa),
                'trb_len': len(trb.junction_aa),
                'paired_cdr3_len': len(tra.junction_aa) + len(trb.junction_aa),
            }
        )
        rows.append(row)

    if not rows:
        # pl.DataFrame([]) would have no columns at all; keep the schema instead.
        string_columns = ['pair_id', 'barcode', *metadata_columns, 'tra_v', 'tra_j', 'tra_cdr3', 'trb_v', 'trb_j', 'trb_cdr3']
        schema = {column: pl.String for column in string_columns}
        schema.update({'tra_len': pl.Int64, 'trb_len': pl.Int64, 'paired_cdr3_len': pl.Int64})
        return pl.DataFrame(schema=schema)
    return pl.DataFrame(rows)


def add_epitope_category(df, focal_epitopes):
    """Add an ``epitope_cat`` column: the epitope itself when it is focal, else 'other'."""
    is_focal = pl.col('epitope').is_in(focal_epitopes)
    # pl.lit is required: a bare string in otherwise() would be read as a column name.
    category = pl.when(is_focal).then(pl.col('epitope')).otherwise(pl.lit('other'))
    return df.with_columns(category.alias('epitope_cat'))


def seeded_sample(df, n):
    """Down-sample ``df`` to ``n`` rows reproducibly.

    Frames with at most ``n`` rows are returned as-is (same object). Larger
    frames are sampled without replacement, shuffled, using the global ``SEED``.
    """
    if df.height > n:
        return df.sample(n=n, with_replacement=False, shuffle=True, seed=SEED)
    return df


def select_balanced(df, focal_epitopes, sample_per_epitope=250, other_sample=500):
    """Build a class-balanced subset of ``df`` keyed on its ``epitope_cat`` column.

    Parameters
    ----------
    df : pl.DataFrame
        Table with an ``epitope_cat`` column (see ``add_epitope_category``).
    focal_epitopes : list[str]
        Categories sampled up to ``sample_per_epitope`` rows each, in this order.
    sample_per_epitope : int
        Row cap per focal epitope.
    other_sample : int
        Row cap for the pooled 'other' category (appended last).

    Returns
    -------
    pl.DataFrame
        Vertical concatenation of the non-empty per-category samples. If every
        category is empty, returns a zero-row frame with ``df``'s schema
        (``pl.concat`` raises on an empty list).
    """
    parts = [
        seeded_sample(df.filter(pl.col('epitope_cat') == ep), sample_per_epitope)
        for ep in focal_epitopes
    ]
    parts.append(seeded_sample(df.filter(pl.col('epitope_cat') == 'other'), other_sample))
    non_empty = [part for part in parts if part.height > 0]
    if not non_empty:
        return df.head(0)
    return pl.concat(non_empty, how='vertical')

strict_df = paired_sample_to_table(strict_sample)
imputed_df = paired_sample_to_table(imputed_sample)

# Exclude technique-driven outlier epitopes from focal modeling diagnostics.
strict_df = strict_df.filter(~pl.col('epitope').is_in(sorted(EXCLUDED_EPITOPES)))
imputed_df = imputed_df.filter(~pl.col('epitope').is_in(sorted(EXCLUDED_EPITOPES)))

# Focal epitopes: the 10 most frequent annotated epitopes in the strict table.
# - Rows without an epitope annotation ('' is the metadata default used when
#   tabulating) are not eligible to become a "focal epitope".
# - Count ties are broken alphabetically. polars group_by output order is not
#   deterministic, so sorting on 'len' alone lets the top-10 cut (and the
#   palette order derived from FOCAL_EPITOPES) change between runs.
FOCAL_EPITOPES = (
    strict_df.filter(pl.col('epitope') != '')
    .group_by('epitope')
    .len()
    .sort(['len', 'epitope'], descending=[True, False])
    .head(10)
    .get_column('epitope')
    .to_list()
)
assert not set(FOCAL_EPITOPES) & EXCLUDED_EPITOPES, 'excluded epitopes leaked into the focal set'
SAMPLE_PER_EPITOPE = 250
OTHER_SAMPLE = 500

strict_df = add_epitope_category(strict_df, FOCAL_EPITOPES)
imputed_df = add_epitope_category(imputed_df, FOCAL_EPITOPES)

strict_sel = select_balanced(strict_df, FOCAL_EPITOPES, SAMPLE_PER_EPITOPE, OTHER_SAMPLE)
imputed_sel = select_balanced(imputed_df, FOCAL_EPITOPES, SAMPLE_PER_EPITOPE, OTHER_SAMPLE)

# Per-category counts, computed once and sorted deterministically (count desc, then name).
strict_sel_counts = strict_sel.group_by('epitope_cat').len().sort(['len', 'epitope_cat'], descending=[True, False])
imputed_sel_counts = imputed_sel.group_by('epitope_cat').len().sort(['len', 'epitope_cat'], descending=[True, False])

print('Focal epitopes (SLL excluded):', FOCAL_EPITOPES)
print('Strict subset size:', strict_sel.height)
print(strict_sel_counts)
print()
print('Imputed subset size:', imputed_sel.height)
print(imputed_sel_counts)
display(strict_sel_counts.to_pandas())
Focal epitopes (SLL excluded): ['KLGGALQAK', 'NLVPMVATV', 'GILGFVFTL', 'RAKFKQLL', 'AVFDRKSDAK', 'GLCTLVAML', 'IVTDFSVIK', 'FLRGRAYGL', 'LLAGIGTVPI', 'YLQPRTFLL']
Strict subset size: 3000
shape: (11, 2)
┌─────────────┬─────┐
│ epitope_cat ┆ len │
│ ---         ┆ --- │
│ str         ┆ u32 │
╞═════════════╪═════╡
│ other       ┆ 500 │
│ NLVPMVATV   ┆ 250 │
│ AVFDRKSDAK  ┆ 250 │
│ GLCTLVAML   ┆ 250 │
│ YLQPRTFLL   ┆ 250 │
│ …           ┆ …   │
│ FLRGRAYGL   ┆ 250 │
│ RAKFKQLL    ┆ 250 │
│ IVTDFSVIK   ┆ 250 │
│ LLAGIGTVPI  ┆ 250 │
│ KLGGALQAK   ┆ 250 │
└─────────────┴─────┘

Imputed subset size: 3000
shape: (11, 2)
┌─────────────┬─────┐
│ epitope_cat ┆ len │
│ ---         ┆ --- │
│ str         ┆ u32 │
╞═════════════╪═════╡
│ other       ┆ 500 │
│ KLGGALQAK   ┆ 250 │
│ GILGFVFTL   ┆ 250 │
│ FLRGRAYGL   ┆ 250 │
│ NLVPMVATV   ┆ 250 │
│ …           ┆ …   │
│ IVTDFSVIK   ┆ 250 │
│ AVFDRKSDAK  ┆ 250 │
│ YLQPRTFLL   ┆ 250 │
│ LLAGIGTVPI  ┆ 250 │
│ GLCTLVAML   ┆ 250 │
└─────────────┴─────┘
epitope_cat len
0 other 500
1 RAKFKQLL 250
2 IVTDFSVIK 250
3 GLCTLVAML 250
4 KLGGALQAK 250
5 GILGFVFTL 250
6 AVFDRKSDAK 250
7 FLRGRAYGL 250
8 YLQPRTFLL 250
9 NLVPMVATV 250
10 LLAGIGTVPI 250
[4]:
# Embed balanced paired subsets and report per-record throughput (paired vs single-chain sum).
# Index the full paired repertoires by pair_id so the balanced table selections
# can be mapped back to the underlying pair objects.
strict_pair_map = {pair.pair_id: pair for pair in strict_pairs}
imputed_pair_map = {pair.pair_id: pair for pair in imputed_pairs}

# Keep the row order of strict_sel / imputed_sel so embedding rows line up with their labels.
strict_selected_pairs = [strict_pair_map[pair_id] for pair_id in strict_sel.get_column('pair_id').to_list()]
imputed_selected_pairs = [imputed_pair_map[pair_id] for pair_id in imputed_sel.get_column('pair_id').to_list()]

# Paired TCREmp model for TRA_TRB; chain1_model / chain2_model (used below) are fed the
# TRA and TRB clonotypes respectively.
N_PROTO = 500
model = PairedTCREmp.from_defaults(species='human', locus_pair='TRA_TRB', n_prototypes=N_PROTO, junction_method='fixed_gap')

# NOTE(review): the strict call runs first, so any lazy one-off initialisation inside
# model.embed (if there is any) would be charged to t_strict_embed — consider a warm-up call.
t0 = time.perf_counter()
X_strict_raw = model.embed(strict_selected_pairs)
t_strict_embed = time.perf_counter() - t0

t0 = time.perf_counter()
X_imputed_raw = model.embed(imputed_selected_pairs)
t_imputed_embed = time.perf_counter() - t0

# Split the strict pairs into their TRA and TRB clonotypes (pairs are not assumed to be
# stored in TRA/TRB order, so pick by locus).
tra_only = [pair.clonotype1 if pair.clonotype1.locus == 'TRA' else pair.clonotype2 for pair in strict_selected_pairs]
trb_only = [pair.clonotype1 if pair.clonotype1.locus == 'TRB' else pair.clonotype2 for pair in strict_selected_pairs]

t0 = time.perf_counter()
X_tra_only = model.chain1_model.embed(tra_only)
t_tra = time.perf_counter() - t0

t0 = time.perf_counter()
X_trb_only = model.chain2_model.embed(trb_only)
t_trb = time.perf_counter() - t0

# Last ratio: paired time relative to the two single-chain calls combined (>1 means the paired
# call is slower than TRA+TRB separately); max() guards against a zero denominator.
print(f'Strict paired embedding:  {X_strict_raw.shape}  {t_strict_embed:.3f}s  ({1e3 * t_strict_embed / len(strict_selected_pairs):.3f} ms/record)')
print(f'Imputed paired embedding: {X_imputed_raw.shape}  {t_imputed_embed:.3f}s  ({1e3 * t_imputed_embed / len(imputed_selected_pairs):.3f} ms/record)')
print(f'Strict TRA only: {t_tra:.3f}s  Strict TRB only: {t_trb:.3f}s  Paired/(TRA+TRB): {t_strict_embed / max(t_tra + t_trb, 1e-9):.3f}x')
Strict paired embedding:  (3000, 3000)  0.094s  (0.031 ms/record)
Imputed paired embedding: (3000, 3000)  0.090s  (0.030 ms/record)
Strict TRA only: 0.041s  Strict TRB only: 0.044s  Paired/(TRA+TRB): 1.102x
[5]:
# Compute PCA variance, kneedle eps, clustering metrics, and chain-level variants.
def analyze_embedding(X_raw, labels):
    """Run DBSCAN diagnostics on an embedding and attach a 2-D UMAP of its PCA projection.

    Parameters
    ----------
    X_raw : array-like
        Embedding matrix, passed unchanged to ``analyze_embedding_dbscan``.
    labels : array-like
        Per-row labels used by the diagnostics (epitope categories in this notebook).

    Returns
    -------
    dict
        The dict returned by ``analyze_embedding_dbscan`` with an extra
        ``'X_umap'`` entry, computed from that dict's ``'X_pca'`` entry.
    """
    dbscan_kwargs = dict(
        seed=SEED,
        pca_variance_threshold=0.90,
        min_samples=3,
        k_neighbors=4,
        consistency_threshold=0.70,
    )
    result = analyze_embedding_dbscan(X_raw, labels, **dbscan_kwargs)

    umap_kwargs = dict(n_components=2, n_neighbors=30, min_dist=0.10, metric='euclidean', random_state=SEED)
    # Silence every warning raised while building and fitting UMAP; the filter is scoped to this block.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        reducer = umap.UMAP(**umap_kwargs)
        result['X_umap'] = reducer.fit_transform(result['X_pca'])
    return result


# Epitope category per selected row, in the same order as the embedding rows.
strict_labels = strict_sel.get_column('epitope_cat').to_numpy()
imputed_labels = imputed_sel.get_column('epitope_cat').to_numpy()

strict_analysis = analyze_embedding(X_strict_raw, strict_labels)
imputed_analysis = analyze_embedding(X_imputed_raw, imputed_labels)

# Chain-only analyses on the same strict paired records, to see how much of the paired
# structure each chain carries on its own (SLLMWITQV is already excluded upstream).
strict_tra_analysis = analyze_embedding(X_tra_only, strict_labels)
strict_trb_analysis = analyze_embedding(X_trb_only, strict_labels)

# Summary column -> key in the dict returned by analyze_embedding(). The column names
# encode the thresholds used there (90% PCA variance, 0.70 consistency).
SUMMARY_METRICS = {
    'pcs_90pct': 'n_comp',
    'eps': 'eps',
    'clusters': 'n_clusters',
    'retention': 'retention',
    'purity': 'purity',
    'consistency_70': 'consistency',
    'median_4nn': 'median_4nn',
}
summary_runs = [
    ('strict-paired', len(strict_selected_pairs), strict_analysis),
    ('imputed-paired', len(imputed_selected_pairs), imputed_analysis),
    ('strict-TRA-only', len(strict_selected_pairs), strict_tra_analysis),
    ('strict-TRB-only', len(strict_selected_pairs), strict_trb_analysis),
]
summary_df = pl.DataFrame(
    [
        {'mode': mode, 'n_pairs': n_pairs, **{column: run[key] for column, key in SUMMARY_METRICS.items()}}
        for mode, n_pairs, run in summary_runs
    ]
)
print(summary_df)
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
shape: (4, 9)
┌────────────┬─────────┬───────────┬──────────┬───┬───────────┬──────────┬────────────┬────────────┐
│ mode       ┆ n_pairs ┆ pcs_90pct ┆ eps      ┆ … ┆ retention ┆ purity   ┆ consistenc ┆ median_4nn │
│ ---        ┆ ---     ┆ ---       ┆ ---      ┆   ┆ ---       ┆ ---      ┆ y_70       ┆ ---        │
│ str        ┆ i64     ┆ i64       ┆ f64      ┆   ┆ f64       ┆ f64      ┆ ---        ┆ f64        │
│            ┆         ┆           ┆          ┆   ┆           ┆          ┆ f64        ┆            │
╞════════════╪═════════╪═══════════╪══════════╪═══╪═══════════╪══════════╪════════════╪════════════╡
│ strict-pai ┆ 3000    ┆ 62        ┆ 0.717217 ┆ … ┆ 0.678333  ┆ 0.564754 ┆ 0.259615   ┆ 0.733947   │
│ red        ┆         ┆           ┆          ┆   ┆           ┆          ┆            ┆            │
│ imputed-pa ┆ 3000    ┆ 62        ┆ 0.738993 ┆ … ┆ 0.745333  ┆ 0.508441 ┆ 0.202703   ┆ 0.751182   │
│ ired       ┆         ┆           ┆          ┆   ┆           ┆          ┆            ┆            │
│ strict-TRA ┆ 3000    ┆ 40        ┆ 0.318535 ┆ … ┆ 0.597     ┆ 0.540373 ┆ 0.246512   ┆ 0.364602   │
│ -only      ┆         ┆           ┆          ┆   ┆           ┆          ┆            ┆            │
│ strict-TRB ┆ 3000    ┆ 25        ┆ 0.28915  ┆ … ┆ 0.635667  ┆ 0.534971 ┆ 0.221053   ┆ 0.299993   │
│ -only      ┆         ┆           ┆          ┆   ┆           ┆          ┆            ┆            │
└────────────┴─────────┴───────────┴──────────┴───┴───────────┴──────────┴────────────┴────────────┘
[6]:
# Render the summary as a styled table: colour gradients on the quality metrics, 3-decimal values.
summary_styler = (
    summary_df.to_pandas()
    .style.background_gradient(cmap='YlGnBu', subset=['purity', 'consistency_70', 'retention'])
    .format(precision=3)
)
display(summary_styler)
  mode n_pairs pcs_90pct eps clusters retention purity consistency_70 median_4nn
0 strict-paired 3000 62 0.717 104 0.678 0.565 0.260 0.734
1 imputed-paired 3000 62 0.739 74 0.745 0.508 0.203 0.751
2 strict-TRA-only 3000 40 0.319 215 0.597 0.540 0.247 0.365
3 strict-TRB-only 3000 25 0.289 190 0.636 0.535 0.221 0.300
[7]:
# Plot PCA variance explained and kneedle curves for strict vs imputed paired embeddings.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Column 0 = strict, column 1 = imputed. Each column gets a PCA-variance panel (top row)
# and a kneedle panel (bottom row) drawn in that run's colour.
for col, analysis, title, color in [
    (0, strict_analysis, 'Strict paired', '#1f77b4'),
    (1, imputed_analysis, 'Imputed paired', '#d62728'),
]:
    # Top: cumulative explained variance over the first <=200 PCs, the 90% line and the PC count used.
    var_ax = axes[0, col]
    n_show = min(200, len(analysis['cum']))
    pc_index = np.arange(1, n_show + 1)
    var_ax.plot(pc_index, analysis['cum'][:n_show] * 100, color=color, lw=1.8)
    var_ax.axhline(90, color='black', ls='--', lw=1.0)
    var_ax.axvline(analysis['n_comp'], color='black', ls=':', lw=1.0)
    var_ax.set_title(f'{title} PCA cumulative variance')
    var_ax.set_xlabel('PC index')
    var_ax.set_ylabel('Cumulative variance (%)')
    var_ax.set_ylim(0, 101)

    # Bottom: sorted 4-NN distance curve, the knee (when one was found) and the eps value.
    knee_ax = axes[1, col]
    knee_ax.plot(analysis['kth'], color=color, lw=1.4)
    knee_idx = analysis['knee_idx']
    if knee_idx is not None:
        knee_ax.axvline(knee_idx, color='black', ls='--', lw=1.0)
    knee_ax.axhline(analysis['eps'], color='black', ls=':', lw=1.0)
    knee_ax.set_title(f'{title} kneedle (4-NN)')
    knee_ax.set_xlabel('Sorted point index')
    knee_ax.set_ylabel('4-NN distance')

plt.tight_layout()
plt.show()

for template, analysis in [
    ('Strict paired:  eps={:.4f}, PCs@90%={}, clusters={}, retention={:.3f}, purity={:.3f}, consistency={:.3f}', strict_analysis),
    ('Imputed paired: eps={:.4f}, PCs@90%={}, clusters={}, retention={:.3f}, purity={:.3f}, consistency={:.3f}', imputed_analysis),
]:
    print(template.format(*(analysis[key] for key in ('eps', 'n_comp', 'n_clusters', 'retention', 'purity', 'consistency'))))
../_images/notebooks_tcremp_vdjdb_analysis_paired_7_0.png
Strict paired:  eps=0.7172, PCs@90%=62, clusters=104, retention=0.678, purity=0.565, consistency=0.260
Imputed paired: eps=0.7390, PCs@90%=62, clusters=74, retention=0.745, purity=0.508, consistency=0.203
[8]:
# Visualize paired embeddings by epitope, by cluster, and by chain-level features.
# One tab20 colour per focal epitope, plus a light grey reserved for the pooled 'other' class.
palette_keys = [*FOCAL_EPITOPES, 'other']
palette_values = list(sns.color_palette('tab20', n_colors=len(FOCAL_EPITOPES))) + [(0.80, 0.80, 0.80)]
PALETTE = {key: rgb for key, rgb in zip(palette_keys, palette_values)}


def scatter_epitopes(ax, coords, labels, title, xlabel, ylabel):
    """Scatter 2-D ``coords`` on ``ax`` coloured by epitope category (global ``PALETTE``).

    'other' points are drawn first, small and faint, as a background layer.
    Each focal epitope present in ``labels`` is then drawn on top with a legend
    label. Focal epitopes with no points are skipped.
    """
    label_arr = np.asarray(labels)
    background = label_arr == 'other'
    ax.scatter(coords[background, 0], coords[background, 1], s=9, alpha=0.20, color=PALETTE['other'], rasterized=True)

    for epitope in FOCAL_EPITOPES:
        selected = label_arr == epitope
        if not selected.any():
            continue
        ax.scatter(
            coords[selected, 0],
            coords[selected, 1],
            s=14,
            alpha=0.78,
            color=PALETTE[epitope],
            rasterized=True,
            label=epitope,
        )

    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel(ylabel, fontsize=10)


def scatter_clusters(ax, coords, clusters, title, xlabel, ylabel, top_n=10):
    """Scatter 2-D ``coords`` on ``ax`` coloured by DBSCAN cluster assignment.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    coords : array
        Point coordinates; columns 0 and 1 are plotted.
    clusters : array-like
        Per-point cluster ids; -1 marks DBSCAN noise.
    title, xlabel, ylabel : str
        Axes text.
    top_n : int
        Number of largest clusters that get their own colour and legend entry.

    The ``top_n`` largest clusters are ranked among real clusters only, so noise
    never takes one of the slots. All remaining clusters are pooled as
    'other_cluster'. Noise is drawn first, in light grey, as a background layer
    so it is not mistaken for a cluster.
    """
    cluster_ids = np.asarray(clusters)
    is_noise = cluster_ids == -1

    # Rank real clusters by size (stable sort, so equal-sized clusters keep id order).
    real_ids, real_counts = np.unique(cluster_ids[~is_noise], return_counts=True)
    top_ids = set(real_ids[np.argsort(-real_counts, kind='stable')[:top_n]].tolist())
    cluster_disp = np.array(
        [
            'noise' if c == -1 else (f'cluster_{c}' if c in top_ids else 'other_cluster')
            for c in cluster_ids.tolist()
        ],
        dtype=object,
    )

    if is_noise.any():
        ax.scatter(coords[is_noise, 0], coords[is_noise, 1], s=8, alpha=0.30, color=(0.80, 0.80, 0.80), rasterized=True, label='noise')

    shown = [label for label in np.unique(cluster_disp) if label != 'noise']
    palette = sns.color_palette('tab20b', n_colors=max(len(shown), 1))
    for color, label in zip(palette, shown):
        mask = cluster_disp == label
        ax.scatter(coords[mask, 0], coords[mask, 1], s=12, alpha=0.75, color=color, rasterized=True, label=label)

    ax.set_title(title, fontsize=12)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel(ylabel, fontsize=10)


# 3x2 grid: strict PCA/UMAP by epitope (row 0), strict PCA/UMAP by cluster (row 1),
# imputed UMAP by epitope and by cluster (row 2). PCA panels use the first two PCs.
fig, axes = plt.subplots(3, 2, figsize=(16, 16), constrained_layout=True)
scatter_epitopes(axes[0, 0], strict_analysis['X_pca'][:, :2], strict_labels, 'Strict paired PCA by epitope', 'PC 1', 'PC 2')
scatter_epitopes(axes[0, 1], strict_analysis['X_umap'], strict_labels, 'Strict paired UMAP by epitope', 'UMAP 1', 'UMAP 2')
scatter_clusters(axes[1, 0], strict_analysis['X_pca'][:, :2], strict_analysis['clusters'], 'Strict paired PCA by cluster', 'PC 1', 'PC 2')
scatter_clusters(axes[1, 1], strict_analysis['X_umap'], strict_analysis['clusters'], 'Strict paired UMAP by cluster', 'UMAP 1', 'UMAP 2')
scatter_epitopes(axes[2, 0], imputed_analysis['X_umap'], imputed_labels, 'Imputed paired UMAP by epitope', 'UMAP 1', 'UMAP 2')
scatter_clusters(axes[2, 1], imputed_analysis['X_umap'], imputed_analysis['clusters'], 'Imputed paired UMAP by cluster', 'UMAP 1', 'UMAP 2')

# Defensive: the scatter helpers only pass label= and never call ax.legend(), so legend_ is
# expected to be None; any per-axes legend is dropped in favour of the figure-level ones below.
for r in range(3):
    for c in range(2):
        if axes[r, c].legend_ is not None:
            axes[r, c].legend_.remove()

# Figure-level legends placed to the right of the grid, built from the strict UMAP panels.
# NOTE(review): scatter_clusters assigns cluster ids/colours per call, so the 'Cluster' legend
# describes the strict cluster panels (row 1) only; the imputed cluster panel (axes[2, 1])
# has its own ids/colours that this legend does not match.
handles_e, labels_e = axes[0, 1].get_legend_handles_labels()
handles_c, labels_c = axes[1, 1].get_legend_handles_labels()
fig.legend(handles_e, labels_e, loc='center left', bbox_to_anchor=(1.01, 0.72), frameon=False, title='Epitope', fontsize=9)
fig.legend(handles_c, labels_c, loc='center left', bbox_to_anchor=(1.01, 0.33), frameon=False, title='Cluster', fontsize=8)
fig.suptitle('Paired VDJdb embeddings after SLL exclusion: epitope and cluster colorings', fontsize=16, y=1.01)
plt.show()


def top_categories(values, top_n=8):
    """Keep the ``top_n`` most frequent values and relabel every other value as 'other'.

    Returns a new numpy array with the same length and order as ``values``.
    """
    categories, frequencies = np.unique(values, return_counts=True)
    ranking = np.argsort(-frequencies)
    keep = set(categories[ranking[:top_n]])
    collapsed = []
    for value in values:
        collapsed.append(value if value in keep else 'other')
    return np.array(collapsed)


# Chain-level features of the strict selection: top-8 V/J genes (rest pooled as 'other')
# and CDR3 lengths, row-aligned with strict_analysis['X_pca'].
strict_tra_v = top_categories(strict_sel.get_column('tra_v').to_numpy())
strict_trb_v = top_categories(strict_sel.get_column('trb_v').to_numpy())
strict_tra_j = top_categories(strict_sel.get_column('tra_j').to_numpy())
strict_trb_j = top_categories(strict_sel.get_column('trb_j').to_numpy())
strict_tra_len = strict_sel.get_column('tra_len').to_numpy()
strict_trb_len = strict_sel.get_column('trb_len').to_numpy()

strict_pca_xy = strict_analysis['X_pca'][:, :2]

fig, axes = plt.subplots(2, 3, figsize=(18, 11), constrained_layout=True)
for ax, values, title in [
    (axes[0, 0], strict_tra_v, 'Strict paired PCA color: TRA V'),
    (axes[0, 1], strict_tra_j, 'Strict paired PCA color: TRA J'),
    (axes[0, 2], strict_tra_len, 'Strict paired PCA color: TRA CDR3 length'),
    (axes[1, 0], strict_trb_v, 'Strict paired PCA color: TRB V'),
    (axes[1, 1], strict_trb_j, 'Strict paired PCA color: TRB J'),
    (axes[1, 2], strict_trb_len, 'Strict paired PCA color: TRB CDR3 length'),
]:
    if values.dtype.kind in {'U', 'O'}:
        # Categorical panel. The pooled 'other' bucket is drawn first in the notebook's grey so the
        # top genes stand out. Every gene is labelled and a legend is drawn, otherwise the colours
        # cannot be read.
        other_mask = values == 'other'
        if other_mask.any():
            ax.scatter(strict_pca_xy[other_mask, 0], strict_pca_xy[other_mask, 1], s=10, alpha=0.40, color=PALETTE['other'], rasterized=True, label='other')
        cats = [cat for cat in np.unique(values) if cat != 'other']
        cat_colors = sns.color_palette('tab10', n_colors=max(len(cats), 1))
        for color, cat in zip(cat_colors, cats):
            mask = values == cat
            ax.scatter(strict_pca_xy[mask, 0], strict_pca_xy[mask, 1], s=10, alpha=0.68, color=color, rasterized=True, label=cat)
        ax.legend(fontsize=7, markerscale=1.5, frameon=True, framealpha=0.8, loc='best')
    else:
        sc = ax.scatter(strict_pca_xy[:, 0], strict_pca_xy[:, 1], c=values, cmap='viridis', s=10, alpha=0.68, rasterized=True)
        fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.04, label='CDR3 length (aa)')
    ax.set_title(title, fontsize=11)
    ax.set_xlabel('PC 1', fontsize=10)
    ax.set_ylabel('PC 2', fontsize=10)

plt.show()
../_images/notebooks_tcremp_vdjdb_analysis_paired_8_0.png
../_images/notebooks_tcremp_vdjdb_analysis_paired_8_1.png
[9]:
# Diagnose outliers after explicit SLL exclusion and report FLR enrichment.
# Epitope singled out for the outlier-rank and enrichment diagnostics in this cell.
TARGET_EPITOPE = 'FLRGRAYGL'
# reference.id value that is_phage_display_expr() flags as phage-display derived,
# regardless of the method_identification text.
PHAGE_REFERENCE = 'PMID:40498839'


def epitope_distance_profile(coords, labels):
    """Summarise how compact and how outlying each epitope category is in ``coords``.

    Parameters
    ----------
    coords : array-like, shape (n_points, n_dims)
        Embedding coordinates (e.g. PCA scores).
    labels : array-like of length n_points
        Category per point; lists are accepted as well as numpy arrays.

    Returns
    -------
    pl.DataFrame
        One row per distinct label with columns:
        - ``n``: number of points,
        - ``mean_intra_dist``: mean distance to the label's own centroid,
        - ``mean_global_dist``: mean distance to the centroid of all points.
        Sorted by ``mean_global_dist`` descending (most outlying first). Empty
        input returns a zero-row frame with these columns.
    """
    coords = np.asarray(coords)
    labels = np.asarray(labels)
    if labels.size == 0:
        return pl.DataFrame(
            schema={'epitope': pl.String, 'n': pl.Int64, 'mean_intra_dist': pl.Float64, 'mean_global_dist': pl.Float64}
        )

    global_centroid = coords.mean(axis=0)
    rows = []
    for ep in np.unique(labels):
        mask = labels == ep
        ep_coords = coords[mask]
        ep_centroid = ep_coords.mean(axis=0)
        intra = np.linalg.norm(ep_coords - ep_centroid, axis=1)
        to_global = np.linalg.norm(ep_coords - global_centroid, axis=1)
        rows.append(
            {
                'epitope': ep,
                'n': int(mask.sum()),
                'mean_intra_dist': float(intra.mean()),
                'mean_global_dist': float(to_global.mean()),
            }
        )
    return pl.DataFrame(rows).sort('mean_global_dist', descending=True)


def is_phage_display_expr() -> pl.Expr:
    """Boolean polars expression flagging rows that look phage-display derived.

    True when ``method_identification`` contains 'phage' (case-insensitive), or
    when ``reference_id`` equals ``PHAGE_REFERENCE``. Nulls in either column are
    treated as ''.
    """
    mentions_phage = (
        pl.col('method_identification')
        .fill_null('')
        .str.to_lowercase()
        .str.contains('phage')
    )
    is_phage_reference = pl.col('reference_id').fill_null('') == PHAGE_REFERENCE
    return mentions_phage | is_phage_reference


def enriched_vj(table, target_epitope, chain='tra', top_n=8, pseudocount=1e-6, min_target_n=1):
    """Rank V/J gene pairs over-represented among TCRs recognising one epitope.

    Parameters
    ----------
    table : pl.DataFrame
        Must contain an ``epitope`` column and ``{chain}_v`` / ``{chain}_j`` columns.
    target_epitope : str
        Epitope whose rows form the target set. All other rows with a
        non-null, different epitope form the background.
    chain : str
        Column prefix, e.g. ``'tra'`` or ``'trb'``.
    top_n : int
        Number of top-ranked V/J pairs to return.
    pseudocount : float
        Added to both fractions before taking the ratio. V/J pairs absent from
        the background then get a large but finite enrichment instead of inf.
    min_target_n : int
        Drop V/J pairs seen fewer than this many times in the target set.
        The default of 1 keeps every observed pair.

    Returns
    -------
    pl.DataFrame
        Columns ``v, j, target_n, target_frac, background_frac, log2_enrichment``,
        sorted by ``log2_enrichment`` then ``target_n`` (both descending).
        If either set is empty, returns an empty frame with that schema.
    """
    v_col = f'{chain}_v'
    j_col = f'{chain}_j'
    target = table.filter(pl.col('epitope') == target_epitope)
    background = table.filter(pl.col('epitope') != target_epitope)

    if target.height == 0 or background.height == 0:
        return pl.DataFrame(
            schema={
                'v': pl.Utf8,
                'j': pl.Utf8,
                'target_n': pl.UInt32,
                'target_frac': pl.Float64,
                'background_frac': pl.Float64,
                'log2_enrichment': pl.Float64,
            }
        )

    target_counts = target.group_by([v_col, j_col]).len().rename({'len': 'target_n'})
    background_counts = background.group_by([v_col, j_col]).len().rename({'len': 'background_n'})

    target_total = float(target.height)
    background_total = float(background.height)

    return (
        target_counts.filter(pl.col('target_n') >= min_target_n)
        .join(background_counts, on=[v_col, j_col], how='left')
        # Must be its own context. Expressions inside one with_columns() are
        # evaluated in parallel against the input frame, so a fill_null placed
        # next to the background_frac expression would not be visible to it.
        # That left background_frac / log2_enrichment null for pairs unseen
        # in the background.
        .with_columns(pl.col('background_n').fill_null(0))
        .with_columns(
            (pl.col('target_n') / target_total).alias('target_frac'),
            (pl.col('background_n') / background_total).alias('background_frac'),
        )
        .with_columns(
            ((pl.col('target_frac') + pseudocount) / (pl.col('background_frac') + pseudocount))
            .log(base=2)
            .alias('log2_enrichment')
        )
        # Descending sorts put nulls first by default. Keep any residual nulls last.
        .sort(['log2_enrichment', 'target_n'], descending=[True, True], nulls_last=True)
        .head(top_n)
        .rename({v_col: 'v', j_col: 'j'})
        .select(['v', 'j', 'target_n', 'target_frac', 'background_frac', 'log2_enrichment'])
    )


strict_profile_paired = epitope_distance_profile(strict_analysis['X_pca'], strict_labels)
strict_profile_tra = epitope_distance_profile(strict_tra_analysis['X_pca'], strict_labels)
strict_profile_trb = epitope_distance_profile(strict_trb_analysis['X_pca'], strict_labels)

for heading, profile in (
    ('Outlier ranking in strict paired PCA (higher mean_global_dist = more outlying):', strict_profile_paired),
    ('\nOutlier ranking in strict TRA-only PCA:', strict_profile_tra),
    ('\nOutlier ranking in strict TRB-only PCA:', strict_profile_trb),
):
    print(heading)
    print(profile.head(8))


def outlier_rank(profile: pl.DataFrame, epitope: str) -> int | None:
    """Return the 1-based position of ``epitope`` in a distance profile, or None if absent."""
    hits = profile.with_row_index().filter(pl.col('epitope') == epitope)
    return int(hits[0, 'index']) + 1 if hits.height else None


flr_rank_paired = outlier_rank(strict_profile_paired, TARGET_EPITOPE)
flr_rank_tra = outlier_rank(strict_profile_tra, TARGET_EPITOPE)
flr_rank_trb = outlier_rank(strict_profile_trb, TARGET_EPITOPE)

print(f"\nFLR outlier rank: paired={flr_rank_paired}, TRA-only={flr_rank_tra}, TRB-only={flr_rank_trb}")
for chain in ('tra', 'trb'):
    print(f'\nTop enriched {chain.upper()} V/J for {TARGET_EPITOPE}:')
    print(enriched_vj(strict_df, TARGET_EPITOPE, chain=chain, top_n=8))
Outlier ranking in strict paired PCA (higher mean_global_dist = more outlying):
shape: (8, 4)
┌────────────┬─────┬─────────────────┬──────────────────┐
│ epitope    ┆ n   ┆ mean_intra_dist ┆ mean_global_dist │
│ ---        ┆ --- ┆ ---             ┆ ---              │
│ str        ┆ i64 ┆ f64             ┆ f64              │
╞════════════╪═════╪═════════════════╪══════════════════╡
│ IVTDFSVIK  ┆ 250 ┆ 0.982779        ┆ 1.000975         │
│ AVFDRKSDAK ┆ 250 ┆ 0.993868        ┆ 1.000735         │
│ other      ┆ 500 ┆ 0.993849        ┆ 1.000527         │
│ NLVPMVATV  ┆ 250 ┆ 0.993313        ┆ 1.000456         │
│ KLGGALQAK  ┆ 250 ┆ 0.993967        ┆ 1.000437         │
│ LLAGIGTVPI ┆ 250 ┆ 0.92092         ┆ 1.000204         │
│ GLCTLVAML  ┆ 250 ┆ 0.941127        ┆ 0.999622         │
│ FLRGRAYGL  ┆ 250 ┆ 0.341326        ┆ 0.998977         │
└────────────┴─────┴─────────────────┴──────────────────┘

Outlier ranking in strict TRA-only PCA:
shape: (8, 4)
┌────────────┬─────┬─────────────────┬──────────────────┐
│ epitope    ┆ n   ┆ mean_intra_dist ┆ mean_global_dist │
│ ---        ┆ --- ┆ ---             ┆ ---              │
│ str        ┆ i64 ┆ f64             ┆ f64              │
╞════════════╪═════╪═════════════════╪══════════════════╡
│ FLRGRAYGL  ┆ 250 ┆ 0.34302         ┆ 1.019087         │
│ GILGFVFTL  ┆ 250 ┆ 0.937323        ┆ 1.003035         │
│ AVFDRKSDAK ┆ 250 ┆ 0.990336        ┆ 1.00013          │
│ IVTDFSVIK  ┆ 250 ┆ 0.976467        ┆ 0.999805         │
│ NLVPMVATV  ┆ 250 ┆ 0.990112        ┆ 0.999394         │
│ other      ┆ 500 ┆ 0.9935          ┆ 0.999081         │
│ KLGGALQAK  ┆ 250 ┆ 0.99339         ┆ 0.999041         │
│ RAKFKQLL   ┆ 250 ┆ 0.944118        ┆ 0.997669         │
└────────────┴─────┴─────────────────┴──────────────────┘

Outlier ranking in strict TRB-only PCA:
shape: (8, 4)
┌────────────┬─────┬─────────────────┬──────────────────┐
│ epitope    ┆ n   ┆ mean_intra_dist ┆ mean_global_dist │
│ ---        ┆ --- ┆ ---             ┆ ---              │
│ str        ┆ i64 ┆ f64             ┆ f64              │
╞════════════╪═════╪═════════════════╪══════════════════╡
│ LLAGIGTVPI ┆ 250 ┆ 0.98697         ┆ 1.007295         │
│ GLCTLVAML  ┆ 250 ┆ 0.941345        ┆ 1.005727         │
│ IVTDFSVIK  ┆ 250 ┆ 0.984157        ┆ 1.003129         │
│ other      ┆ 500 ┆ 0.994459        ┆ 1.00295          │
│ KLGGALQAK  ┆ 250 ┆ 0.994151        ┆ 1.002048         │
│ AVFDRKSDAK ┆ 250 ┆ 0.997422        ┆ 1.00191          │
│ NLVPMVATV  ┆ 250 ┆ 0.996075        ┆ 1.001338         │
│ YLQPRTFLL  ┆ 250 ┆ 0.955683        ┆ 0.997374         │
└────────────┴─────┴─────────────────┴──────────────────┘

FLR outlier rank: paired=8, TRA-only=1, TRB-only=11

Top enriched TRA V/J for FLRGRAYGL:
shape: (8, 6)
┌─────────────┬───────────┬──────────┬─────────────┬─────────────────┬─────────────────┐
│ v           ┆ j         ┆ target_n ┆ target_frac ┆ background_frac ┆ log2_enrichment │
│ ---         ┆ ---       ┆ ---      ┆ ---         ┆ ---             ┆ ---             │
│ str         ┆ str       ┆ u32      ┆ f64         ┆ f64             ┆ f64             │
╞═════════════╪═══════════╪══════════╪═════════════╪═════════════════╪═════════════════╡
│ TRAV26-2*01 ┆ TRAJ52*01 ┆ 456      ┆ 0.820144    ┆ 0.001399        ┆ 9.194143        │
│ TRAV4*01    ┆ TRAJ52*01 ┆ 3        ┆ 0.005396    ┆ 0.00006         ┆ 6.467971        │
│ TRAV3*01    ┆ TRAJ30*01 ┆ 28       ┆ 0.05036     ┆ 0.000939        ┆ 5.742825        │
│ TRAV30*01   ┆ TRAJ29*01 ┆ 4        ┆ 0.007194    ┆ 0.00016         ┆ 5.482771        │
│ TRAV2*01    ┆ TRAJ13*01 ┆ 4        ┆ 0.007194    ┆ 0.0002          ┆ 5.162637        │
│ TRAV26-1*01 ┆ TRAJ44*01 ┆ 1        ┆ 0.001799    ┆ 0.00008         ┆ 4.474434        │
│ TRAV22*01   ┆ TRAJ13*01 ┆ 1        ┆ 0.001799    ┆ 0.00012         ┆ 3.895424        │
│ TRAV12-2*01 ┆ TRAJ27*01 ┆ 6        ┆ 0.010791    ┆ 0.000879        ┆ 3.61559         │
└─────────────┴───────────┴──────────┴─────────────┴─────────────────┴─────────────────┘

Top enriched TRB V/J for FLRGRAYGL:
shape: (8, 6)
┌────────────┬────────────┬──────────┬─────────────┬─────────────────┬─────────────────┐
│ v          ┆ j          ┆ target_n ┆ target_frac ┆ background_frac ┆ log2_enrichment │
│ ---        ┆ ---        ┆ ---      ┆ ---         ┆ ---             ┆ ---             │
│ str        ┆ str        ┆ u32      ┆ f64         ┆ f64             ┆ f64             │
╞════════════╪════════════╪══════════╪═════════════╪═════════════════╪═════════════════╡
│ TRBV6S3*01 ┆ TRBJ2-1*01 ┆ 5        ┆ 0.008993    ┆ null            ┆ null            │
│ TRBV7-8*01 ┆ TRBJ2-7*01 ┆ 451      ┆ 0.811151    ┆ 0.002978        ┆ 8.088897        │
│ TRBV6-2*01 ┆ TRBJ2-4*01 ┆ 1        ┆ 0.001799    ┆ 0.00002         ┆ 6.421936        │
│ TRBV6-4*01 ┆ TRBJ1-6*01 ┆ 5        ┆ 0.008993    ┆ 0.00022         ┆ 5.347675        │
│ TRBV7-6*01 ┆ TRBJ2-1*01 ┆ 28       ┆ 0.05036     ┆ 0.002459        ┆ 4.355847        │
│ TRBV7-6*01 ┆ TRBJ2-7*01 ┆ 10       ┆ 0.017986    ┆ 0.002099        ┆ 3.09864         │
│ TRBV6-1*01 ┆ TRBJ1-4*01 ┆ 2        ┆ 0.003597    ┆ 0.00068         ┆ 2.402382        │
│ TRBV2*01   ┆ TRBJ1-2*01 ┆ 6        ┆ 0.010791    ┆ 0.002299        ┆ 2.230543        │
└────────────┴────────────┴──────────┴─────────────┴─────────────────┴─────────────────┘
[10]:
# Print compact exclusion and FLR summary for downstream reporting.
# The provenance check below targets the single curated exclusion. Fail loudly
# if the exclusion set configured at the top of the notebook no longer contains it.
SLL_EPITOPE = 'SLLMWITQV'
assert SLL_EPITOPE in EXCLUDED_EPITOPES, f'{SLL_EPITOPE} is no longer in EXCLUDED_EPITOPES'

# Re-read the raw VDJdb export (dotted column names) so provenance counts
# reflect the source file before any of this notebook's filtering.
vdjdb_all = pl.read_csv(vdjdb_full, separator='\t', null_values=['', 'NA'], ignore_errors=True)
sll_all = vdjdb_all.filter(pl.col('antigen.epitope') == SLL_EPITOPE)
# Phage criterion: method mentions "phage" or the record comes from PHAGE_REFERENCE.
phage_all_n = sll_all.filter(
    pl.col('method.identification').cast(pl.Utf8).fill_null('').str.to_lowercase().str.contains('phage')
    | (pl.col('reference.id').cast(pl.Utf8).fill_null('') == PHAGE_REFERENCE)
).height
print({'sll_total_vdjdb_all': sll_all.height, 'sll_phage_vdjdb_all': phage_all_n})

sll_kept_n = strict_df.filter(pl.col('epitope') == SLL_EPITOPE).height
print(f'{SLL_EPITOPE} kept in strict_df after exclusion:', sll_kept_n)
assert sll_kept_n == 0, f'{SLL_EPITOPE} leaked into strict_df despite exclusion'

flr_tra_top3 = enriched_vj(strict_df, TARGET_EPITOPE, chain='tra', top_n=3)
flr_trb_top3 = enriched_vj(strict_df, TARGET_EPITOPE, chain='trb', top_n=3)

# Rich display only. A plain-text print of the same frames duplicated the table.
for chain_label, top3 in (('TRA', flr_tra_top3), ('TRB', flr_trb_top3)):
    display(
        top3.to_pandas()
        .style.set_caption(f'Top {TARGET_EPITOPE} {chain_label} V/J')
        .background_gradient(subset=['log2_enrichment'], cmap='YlGnBu')
        .format(precision=3)
    )
{'sll_total_vdjdb_all': 29727, 'sll_phage_vdjdb_all': 29715}
SLLMWITQV kept in strict_df after exclusion: 0
Top FLR TRA V/J:
 shape: (3, 6)
┌─────────────┬───────────┬──────────┬─────────────┬─────────────────┬─────────────────┐
│ v           ┆ j         ┆ target_n ┆ target_frac ┆ background_frac ┆ log2_enrichment │
│ ---         ┆ ---       ┆ ---      ┆ ---         ┆ ---             ┆ ---             │
│ str         ┆ str       ┆ u32      ┆ f64         ┆ f64             ┆ f64             │
╞═════════════╪═══════════╪══════════╪═════════════╪═════════════════╪═════════════════╡
│ TRAV26-2*01 ┆ TRAJ52*01 ┆ 456      ┆ 0.820144    ┆ 0.001399        ┆ 9.194143        │
│ TRAV4*01    ┆ TRAJ52*01 ┆ 3        ┆ 0.005396    ┆ 0.00006         ┆ 6.467971        │
│ TRAV3*01    ┆ TRAJ30*01 ┆ 28       ┆ 0.05036     ┆ 0.000939        ┆ 5.742825        │
└─────────────┴───────────┴──────────┴─────────────┴─────────────────┴─────────────────┘
Top FLR TRB V/J:
 shape: (3, 6)
┌────────────┬────────────┬──────────┬─────────────┬─────────────────┬─────────────────┐
│ v          ┆ j          ┆ target_n ┆ target_frac ┆ background_frac ┆ log2_enrichment │
│ ---        ┆ ---        ┆ ---      ┆ ---         ┆ ---             ┆ ---             │
│ str        ┆ str        ┆ u32      ┆ f64         ┆ f64             ┆ f64             │
╞════════════╪════════════╪══════════╪═════════════╪═════════════════╪═════════════════╡
│ TRBV6S3*01 ┆ TRBJ2-1*01 ┆ 5        ┆ 0.008993    ┆ null            ┆ null            │
│ TRBV7-8*01 ┆ TRBJ2-7*01 ┆ 451      ┆ 0.811151    ┆ 0.002978        ┆ 8.088897        │
│ TRBV6-2*01 ┆ TRBJ2-4*01 ┆ 1        ┆ 0.001799    ┆ 0.00002         ┆ 6.421936        │
└────────────┴────────────┴──────────┴─────────────┴─────────────────┴─────────────────┘
  v j target_n target_frac background_frac log2_enrichment
0 TRAV26-2*01 TRAJ52*01 456 0.820 0.001 9.194
1 TRAV4*01 TRAJ52*01 3 0.005 0.000 6.468
2 TRAV3*01 TRAJ30*01 28 0.050 0.001 5.743
  v j target_n target_frac background_frac log2_enrichment
0 TRBV6S3*01 TRBJ2-1*01 5 0.009 nan nan
1 TRBV7-8*01 TRBJ2-7*01 451 0.811 0.003 8.089
2 TRBV6-2*01 TRBJ2-4*01 1 0.002 0.000 6.422

Final notes

  • The paired VDJdb analysis excludes SLLMWITQV from focal embedding diagnostics.

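  • Rationale: this epitope's records are dominated by a single technique. In the full VDJdb export, 29,715 of 29,727 SLLMWITQV records are flagged as phage-display-derived (method mentions phage, or reference PMID:40498839). That enrichment centres on a single degenerate TCR pattern, which would distort the broader repertoire geometry if kept.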

  • Provenance counts are still reported for transparency, while interpretation centers on non-excluded epitopes.

[11]:
# Emit a concise plain-text summary suitable for manuscript-style reporting.
summary_lines = [
    'Summary',
    '-------',
    f"Strict paired embedding retention: {strict_analysis['retention']:.3f}",
    f"Imputed paired embedding retention: {imputed_analysis['retention']:.3f}",
]
print('\n'.join(summary_lines))

# One line per embedding mode, highest cluster purity first.
ranked_modes = summary_df.sort('purity', descending=True)
for row in ranked_modes.iter_rows(named=True):
    mode_line = (
        f"{row['mode']}: eps={row['eps']:.3f}, clusters={int(row['clusters'])}, "
        f"purity={row['purity']:.3f}, consistency@0.70={row['consistency_70']:.3f}, "
        f"retention={row['retention']:.3f}, median_4nn={row['median_4nn']:.3f}"
    )
    print(mode_line)

for closing_line in (
    '\nExclusion policy: SLLMWITQV was excluded from focal diagnostics due to technique-specific degenerate-pattern enrichment.',
    'Key finding: FLRGRAYGL shows strong TRA-driven enrichment with distinctive V/J pairing signatures and remains separable in paired UMAP space.',
):
    print(closing_line)
Summary
-------
Strict paired embedding retention: 0.678
Imputed paired embedding retention: 0.745
strict-paired: eps=0.717, clusters=104, purity=0.565, consistency@0.70=0.260, retention=0.678, median_4nn=0.734
strict-TRA-only: eps=0.319, clusters=215, purity=0.540, consistency@0.70=0.247, retention=0.597, median_4nn=0.365
strict-TRB-only: eps=0.289, clusters=190, purity=0.535, consistency@0.70=0.221, retention=0.636, median_4nn=0.300
imputed-paired: eps=0.739, clusters=74, purity=0.508, consistency@0.70=0.203, retention=0.745, median_4nn=0.751

Exclusion policy: SLLMWITQV was excluded from focal diagnostics due to technique-specific degenerate-pattern enrichment.
Key finding: FLRGRAYGL shows strong TRA-driven enrichment with distinctive V/J pairing signatures and remains separable in paired UMAP space.