TCREmp 10x dcode Analysis#

This notebook reproduces a comprehensive 10x dcode workflow with CITE-seq integration:

  • load donor-level paired VDJ plus binarized CITE-seq matrix into SingleCellSample;

  • sanity-check binder definitions against VDJdb records with 10x references;

  • derive donor1 pseudo-labels from positive binder channels;

  • run TRA/TRB/TRA_TRB embeddings with PCA, k-nearest-neighbor eps diagnostics, and UMAP views;

  • compute epitope-level classification scores.

Curation note:

  • KLGGALQAK is excluded from donor-level modeling in this notebook based on personal communication from 10x indicating this dcode target can be somewhat spurious for this setting.

[1]:
# Set deterministic seed, display settings, and print environment versions used in this analysis.
from __future__ import annotations

import random
import time
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
import sklearn
import umap
from IPython.display import display

from mir.common.single_cell import (
    load_10x_vdj_v1_citeseq_sample,
    validate_citeseq_binders_against_vdjdb_10x,
)
from mir.embedding.tcremp import PairedTCREmp
from mir.utils.embedding_diagnostics import (
    analyze_embedding_dbscan,
    classification_scores_by_label,
    majority_vote_cluster_predictions,
)
from mir.utils.notebook_assets import (
    ensure_airr_benchmark,
    find_airr_benchmark_dcode_10x_vdj_v1_donor,
    find_airr_benchmark_dcode_10x_vdj_v1_donor_matrix,
    find_airr_benchmark_vdjdb_full,
    find_repo_root,
)

SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# Exclude dcode epitopes with known curation caveats for this analysis pass.
EXCLUDED_EPITOPES = {'KLGGALQAK'}

sns.set_theme(style='whitegrid', context='talk')
plt.rcParams['figure.dpi'] = 140
plt.rcParams['savefig.dpi'] = 300

print(f'Python: {__import__("sys").version.split()[0]}')
print(f'numpy: {np.__version__}')
print(f'polars: {pl.__version__}')
print(f'pandas: {pd.__version__}')
print(f'scikit-learn: {sklearn.__version__}')
print(f'umap-learn: {umap.__version__}')
print(f'Excluded epitopes in donor modeling: {sorted(EXCLUDED_EPITOPES)}')
/Users/mikesh/vcs/mirpy/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Python: 3.12.12
numpy: 1.26.4
polars: 1.40.1
pandas: 3.0.3
scikit-learn: 1.8.0
umap-learn: 0.5.12
Excluded epitopes in donor modeling: ['KLGGALQAK']
[2]:
# Load AIRR benchmark assets and build SingleCellSample objects for all 10x_vdj_v1 donors.
repo_root = find_repo_root(Path.cwd())
dataset_root = ensure_airr_benchmark(repo_root=repo_root, allow_patterns=['dcode/*', 'vdjdb/**'])
vdjdb_full = find_airr_benchmark_vdjdb_full(dataset_root)

donor_ids = ['donor1', 'donor2', 'donor3', 'donor4']
samples = {}
donor_paths = {}

for donor_id in donor_ids:
    all_contig, consensus = find_airr_benchmark_dcode_10x_vdj_v1_donor(dataset_root, donor_id)
    matrix = find_airr_benchmark_dcode_10x_vdj_v1_donor_matrix(dataset_root, donor_id)
    donor_paths[donor_id] = {'all_contig': all_contig, 'consensus': consensus, 'matrix': matrix}
    t0 = time.perf_counter()
    sample = load_10x_vdj_v1_citeseq_sample(
        consensus_annotations_path=consensus,
        all_contig_annotations_path=all_contig,
        binarized_matrix_path=matrix,
        sample_id=donor_id,
    )
    dt = time.perf_counter() - t0
    samples[donor_id] = sample
    print(
        f'{donor_id}: cells={sample.paired_repertoire.loaded_cell_count:,} '
        f'pairs={sample.paired_repertoire.paired_locus_repertoires["TRA_TRB"].clonotype_count:,} '
        f'matrix_rows={sample.cite_seq_matrix.height:,} binder_cols={sample.cite_seq_binder_columns.height:,} '
        f'time={dt:.2f}s'
    )
donor1: cells=47,271 pairs=48,890 matrix_rows=46,526 binder_cols=78 time=1.12s
donor2: cells=79,704 pairs=72,266 matrix_rows=77,854 binder_cols=78 time=1.77s
donor3: cells=38,095 pairs=39,518 matrix_rows=37,824 binder_cols=78 time=1.13s
donor4: cells=27,640 pairs=29,147 matrix_rows=27,308 binder_cols=78 time=0.68s
[3]:
# Validate donor binder definitions against VDJdb records annotated with 10x references.
sanity_rows = []
for donor_id, sample in samples.items():
    missing = validate_citeseq_binders_against_vdjdb_10x(sample.cite_seq_binder_columns, vdjdb_full)
    sanity_rows.append({
        'donor_id': donor_id,
        'binder_columns': sample.cite_seq_binder_columns.height,
        'missing_targets_vs_vdjdb_10x': missing.height,
    })
    if missing.height > 0:
        print(f'\n{donor_id} residual unmatched targets:')
        print(missing)

sanity_df = pl.DataFrame(sanity_rows).sort('donor_id')
print(sanity_df)

donor1 residual unmatched targets:
shape: (2, 5)
┌─────────────────────────────────┬───────┬─────────────────┬──────────────┬─────────────────┐
│ column                          ┆ hla   ┆ antigen.epitope ┆ antigen.gene ┆ antigen.species │
│ ---                             ┆ ---   ┆ ---             ┆ ---          ┆ ---             │
│ str                             ┆ str   ┆ str             ┆ str          ┆ str             │
╞═════════════════════════════════╪═══════╪═════════════════╪══════════════╪═════════════════╡
│ A0201_CLGGLLTMV_LMP-2A_EBV      ┆ A0201 ┆ CLGGLLTMV       ┆ LMP-2A       ┆ EBV             │
│ A0201_LLMGTLGIVC_HPV-16E7_82-9… ┆ A0201 ┆ LLMGTLGIVC      ┆ HPV-16E7     ┆ 82-91           │
└─────────────────────────────────┴───────┴─────────────────┴──────────────┴─────────────────┘

donor2 residual unmatched targets:
shape: (2, 5)
┌─────────────────────────────────┬───────┬─────────────────┬──────────────┬─────────────────┐
│ column                          ┆ hla   ┆ antigen.epitope ┆ antigen.gene ┆ antigen.species │
│ ---                             ┆ ---   ┆ ---             ┆ ---          ┆ ---             │
│ str                             ┆ str   ┆ str             ┆ str          ┆ str             │
╞═════════════════════════════════╪═══════╪═════════════════╪══════════════╪═════════════════╡
│ A0201_CLGGLLTMV_LMP-2A_EBV      ┆ A0201 ┆ CLGGLLTMV       ┆ LMP-2A       ┆ EBV             │
│ A0201_LLMGTLGIVC_HPV-16E7_82-9… ┆ A0201 ┆ LLMGTLGIVC      ┆ HPV-16E7     ┆ 82-91           │
└─────────────────────────────────┴───────┴─────────────────┴──────────────┴─────────────────┘

donor3 residual unmatched targets:
shape: (2, 5)
┌─────────────────────────────────┬───────┬─────────────────┬──────────────┬─────────────────┐
│ column                          ┆ hla   ┆ antigen.epitope ┆ antigen.gene ┆ antigen.species │
│ ---                             ┆ ---   ┆ ---             ┆ ---          ┆ ---             │
│ str                             ┆ str   ┆ str             ┆ str          ┆ str             │
╞═════════════════════════════════╪═══════╪═════════════════╪══════════════╪═════════════════╡
│ A0201_CLGGLLTMV_LMP-2A_EBV      ┆ A0201 ┆ CLGGLLTMV       ┆ LMP-2A       ┆ EBV             │
│ A0201_LLMGTLGIVC_HPV-16E7_82-9… ┆ A0201 ┆ LLMGTLGIVC      ┆ HPV-16E7     ┆ 82-91           │
└─────────────────────────────────┴───────┴─────────────────┴──────────────┴─────────────────┘

donor4 residual unmatched targets:
shape: (2, 5)
┌─────────────────────────────────┬───────┬─────────────────┬──────────────┬─────────────────┐
│ column                          ┆ hla   ┆ antigen.epitope ┆ antigen.gene ┆ antigen.species │
│ ---                             ┆ ---   ┆ ---             ┆ ---          ┆ ---             │
│ str                             ┆ str   ┆ str             ┆ str          ┆ str             │
╞═════════════════════════════════╪═══════╪═════════════════╪══════════════╪═════════════════╡
│ A0201_CLGGLLTMV_LMP-2A_EBV      ┆ A0201 ┆ CLGGLLTMV       ┆ LMP-2A       ┆ EBV             │
│ A0201_LLMGTLGIVC_HPV-16E7_82-9… ┆ A0201 ┆ LLMGTLGIVC      ┆ HPV-16E7     ┆ 82-91           │
└─────────────────────────────────┴───────┴─────────────────┴──────────────┴─────────────────┘
shape: (4, 3)
┌──────────┬────────────────┬──────────────────────────────┐
│ donor_id ┆ binder_columns ┆ missing_targets_vs_vdjdb_10x │
│ ---      ┆ ---            ┆ ---                          │
│ str      ┆ i64            ┆ i64                          │
╞══════════╪════════════════╪══════════════════════════════╡
│ donor1   ┆ 78             ┆ 2                            │
│ donor2   ┆ 78             ┆ 2                            │
│ donor3   ┆ 78             ┆ 2                            │
│ donor4   ┆ 78             ┆ 2                            │
└──────────┴────────────────┴──────────────────────────────┘
[4]:
# Build donor1 paired clonotype labels from positive CITE-seq binder channels.
# Deduplicate by paired (TRA + TRB) ntvj combination so each unique clonotype
# is counted once — cells sharing the same paired amino-acid sequence are pooled.
def donor1_binder_labels(sample):
    binder_cols = sample.cite_seq_binder_columns.filter(pl.col('is_binder') == '1').get_column('column').to_list()
    if not binder_cols:
        raise ValueError('No binder columns found in donor CITE-seq matrix')

    long = (
        sample.cite_seq_matrix
        .select(['barcode', *binder_cols])
        .unpivot(index='barcode', variable_name='column', value_name='value')
        .filter(pl.col('value') == 1)
    )
    if long.height == 0:
        raise ValueError('No positive binder events found')

    info = sample.cite_seq_binder_columns.select(['column', 'hla', 'antigen.epitope'])
    labels = (
        long.join(info, on='column', how='left')
        .group_by('barcode', 'antigen.epitope')
        .len()
        .sort(['barcode', 'len', 'antigen.epitope'], descending=[False, True, False])
        .group_by('barcode')
        .first()
        .select(['barcode', pl.col('antigen.epitope').alias('epitope_label')])
        .filter(~pl.col('epitope_label').is_in(sorted(EXCLUDED_EPITOPES)))
    )
    return labels


def _pair_ntvj_key(pair) -> tuple:
    """Canonical key: (TRA_aa, TRA_v, TRA_j, TRB_aa, TRB_v, TRB_j)."""
    c1, c2 = pair.clonotype1, pair.clonotype2
    tra = c1 if c1.locus == 'TRA' else c2
    trb = c2 if c2.locus == 'TRB' else c1
    return (
        tra.junction_aa or '', tra.v_gene or '', tra.j_gene or '',
        trb.junction_aa or '', trb.v_gene or '', trb.j_gene or '',
    )


donor1 = samples['donor1']
labels_df = donor1_binder_labels(donor1)

all_contig_df = pl.read_csv(donor_paths['donor1']['all_contig']).select(['barcode', 'raw_clonotype_id']).drop_nulls()
clonotype_labels = (
    all_contig_df.join(labels_df, on='barcode', how='inner')
    .group_by('raw_clonotype_id', 'epitope_label')
    .len()
    .sort(['raw_clonotype_id', 'len', 'epitope_label'], descending=[False, True, False])
    .group_by('raw_clonotype_id')
    .first()
    .rename({'raw_clonotype_id': 'clonotype_id'})
    .select(['clonotype_id', 'epitope_label'])
)
clonotype_label_map = {
    row['clonotype_id']: row['epitope_label']
    for row in clonotype_labels.iter_rows(named=True)
}

# Collect all pairs with CITE labels, then deduplicate by paired ntvj.
all_pairs_raw = []
all_rows_raw = []
for pair in donor1.paired_repertoire.paired_locus_repertoires['TRA_TRB'].paired_clonotypes:
    clonotype_id = pair.pair_id.split('_', 1)[0]
    epitope = clonotype_label_map.get(clonotype_id)
    if epitope is None:
        continue
    all_pairs_raw.append(pair)
    all_rows_raw.append({'clonotype_id': clonotype_id, 'epitope_label': epitope,
                         'ntvj_key': _pair_ntvj_key(pair)})

print(f'Paired clonotypes with CITE labels (before dedup): {len(all_pairs_raw):,}')

# Deduplicate: keep one representative pair per unique paired ntvj key.
# Epitope label: use most common label among cells sharing that key.
seen_keys: dict[tuple, dict] = {}
for pair, row in zip(all_pairs_raw, all_rows_raw):
    k = row['ntvj_key']
    if k not in seen_keys:
        seen_keys[k] = {'pair': pair, 'epitope_counts': {}}
    label = row['epitope_label']
    seen_keys[k]['epitope_counts'][label] = seen_keys[k]['epitope_counts'].get(label, 0) + 1

pairs = []
rows = []
for k, info in seen_keys.items():
    pairs.append(info['pair'])
    best_epitope = max(info['epitope_counts'], key=info['epitope_counts'].get)
    rows.append({'clonotype_id': info['pair'].pair_id.split('_', 1)[0],
                 'epitope_label': best_epitope})

label_table = pl.DataFrame(rows)
label_counts = label_table.group_by('epitope_label').len().sort('len', descending=True)

print(f'Unique paired clonotypes after ntvj dedup: {len(pairs):,}')
print(f'Deduplication removed {len(all_pairs_raw) - len(pairs):,} duplicate cells ({100*(len(all_pairs_raw)-len(pairs))/len(all_pairs_raw):.1f}%)')
print(label_counts.head(10))
display(label_counts.to_pandas())

Paired clonotypes with CITE labels (before dedup): 18,431
Unique paired clonotypes after ntvj dedup: 2,711
Deduplication removed 15,720 duplicate cells (85.3%)
shape: (10, 2)
┌───────────────┬─────┐
│ epitope_label ┆ len │
│ ---           ┆ --- │
│ str           ┆ u32 │
╞═══════════════╪═════╡
│ GILGFVFTL     ┆ 698 │
│ AVFDRKSDAK    ┆ 563 │
│ RAKFKQLL      ┆ 469 │
│ IVTDFSVIK     ┆ 444 │
│ ELAGIGILTV    ┆ 263 │
│ RLRAEAQVK     ┆ 105 │
│ GLCTLVAML     ┆ 31  │
│ LLDFVRFMGV    ┆ 27  │
│ MLDLQPETT     ┆ 13  │
│ SLFNTVATLY    ┆ 11  │
└───────────────┴─────┘
epitope_label len
0 GILGFVFTL 698
1 AVFDRKSDAK 563
2 RAKFKQLL 469
3 IVTDFSVIK 444
4 ELAGIGILTV 263
5 RLRAEAQVK 105
6 GLCTLVAML 31
7 LLDFVRFMGV 27
8 MLDLQPETT 13
9 SLFNTVATLY 11
10 CYTWNQMNL 11
11 FLYALALLL 10
12 AYAQKIFKI 9
13 RTLNAWVKV 8
14 YLLEMLWRL 7
15 YLNDHLEPWI 6
16 IMDQVPFSV 6
17 KTWGQYWQV 6
18 SLLMWITQV 5
19 FLRGRAYGL 4
20 QYDPVAALF 4
21 KVLEYVIKV 4
22 IPSINVHHY 3
23 QPRAPIRPI 2
24 SLFNTVATL 1
25 RIAAWMATY 1
[5]:
# Embed donor1 balanced clonotypes in TRA, TRB, and paired spaces and compute clustering/classification diagnostics.
MAX_PER_EPITOPE = 250
OTHER_LABEL = 'other'
TOP_LABELS = label_table.group_by('epitope_label').len().sort('len', descending=True).head(10).get_column('epitope_label').to_list()

balanced_parts = []
for label in TOP_LABELS:
    part = label_table.filter(pl.col('epitope_label') == label)
    balanced_parts.append(part.sample(n=min(MAX_PER_EPITOPE, part.height), seed=SEED, shuffle=True))

other = label_table.filter(~pl.col('epitope_label').is_in(TOP_LABELS))
if other.height > 0:
    balanced_parts.append(other.sample(n=min(500, other.height), seed=SEED, shuffle=True).with_columns(pl.lit(OTHER_LABEL).alias('epitope_label')))

balanced = pl.concat(balanced_parts, how='vertical')
pair_index = {pair.pair_id.split('_', 1)[0]: pair for pair in pairs}
balanced_pairs = [pair_index[c] for c in balanced.get_column('clonotype_id').to_list()]
balanced_labels = balanced.get_column('epitope_label').to_numpy()

model = PairedTCREmp.from_defaults(species='human', locus_pair='TRA_TRB', n_prototypes=500, junction_method='fixed_gap')

tra_only = [pair.clonotype1 if pair.clonotype1.locus == 'TRA' else pair.clonotype2 for pair in balanced_pairs]
trb_only = [pair.clonotype1 if pair.clonotype1.locus == 'TRB' else pair.clonotype2 for pair in balanced_pairs]

t0 = time.perf_counter()
X_pair = model.embed(balanced_pairs)
t_pair = time.perf_counter() - t0

t0 = time.perf_counter()
X_tra = model.chain1_model.embed(tra_only)
t_tra = time.perf_counter() - t0

t0 = time.perf_counter()
X_trb = model.chain2_model.embed(trb_only)
t_trb = time.perf_counter() - t0

mode_embeddings = {
    'TRA': X_tra,
    'TRB': X_trb,
    'TRA_TRB': X_pair,
}

mode_results = {}
mode_rows = []
score_rows = []

for mode, X_mode in mode_embeddings.items():
    analysis = analyze_embedding_dbscan(
        X_mode,
        balanced_labels,
        seed=SEED,
        min_samples=2,
        k_neighbors=4,
        pca_variance_threshold=0.90,
    )
    predicted = majority_vote_cluster_predictions(balanced_labels, analysis['clusters'])
    scores = classification_scores_by_label(balanced_labels, predicted)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        X_umap = umap.UMAP(
            n_components=2,
            n_neighbors=30,
            min_dist=0.10,
            metric='euclidean',
            random_state=SEED,
        ).fit_transform(analysis['X_pca'])

    mode_results[mode] = {
        'X': X_mode,
        'analysis': analysis,
        'predicted': predicted,
        'scores': scores,
        'X_umap': X_umap,
    }

    mode_rows.append(
        {
            'mode': mode,
            'n_records': int(len(balanced_labels)),
            'embed_time_s': float({'TRA': t_tra, 'TRB': t_trb, 'TRA_TRB': t_pair}[mode]),
            'pcs_90pct': int(analysis['n_comp']),
            'eps': float(analysis['eps']),
            'clusters': int(analysis['n_clusters']),
            'retention': float(analysis['retention']),
            'purity': float(analysis['purity']),
            'consistency_70': float(analysis['consistency']),
            'macro_f1': float(scores['macro_f1']),
            'weighted_f1': float(scores['weighted_f1']),
            'accuracy': float(scores['accuracy']),
        }
    )

    for rec in scores['per_label']:
        score_rows.append(
            {
                'mode': mode,
                'epitope': rec['label'],
                'precision': float(rec['precision']),
                'recall': float(rec['recall']),
                'f1': float(rec['f1']),
                'support': int(rec['support']),
            }
        )

mode_summary = pl.DataFrame(mode_rows).sort('mode')
epitope_scores = pl.DataFrame(score_rows).sort(['mode', 'f1'], descending=[False, True])

print(f'Balanced donor1 curated records: {balanced.height:,}')
print(f'Embedding runtime TRA: {t_tra:.3f}s  TRB: {t_trb:.3f}s  TRA_TRB: {t_pair:.3f}s')
print(mode_summary)
print('\nTop epitope classification rows by F1 per mode:')
print(epitope_scores.group_by('mode').head(8).sort(['mode', 'f1'], descending=[False, True]))

display(mode_summary.to_pandas().style.background_gradient(subset=['weighted_f1', 'accuracy'], cmap='YlGnBu').format(precision=3))
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
Balanced donor1 curated records: 1,524
Embedding runtime TRA: 0.020s  TRB: 0.021s  TRA_TRB: 0.044s
shape: (3, 12)
┌─────────┬───────────┬────────────┬───────────┬───┬────────────┬──────────┬────────────┬──────────┐
│ mode    ┆ n_records ┆ embed_time ┆ pcs_90pct ┆ … ┆ consistenc ┆ macro_f1 ┆ weighted_f ┆ accuracy │
│ ---     ┆ ---       ┆ _s         ┆ ---       ┆   ┆ y_70       ┆ ---      ┆ 1          ┆ ---      │
│ str     ┆ i64       ┆ ---        ┆ i64       ┆   ┆ ---        ┆ f64      ┆ ---        ┆ f64      │
│         ┆           ┆ f64        ┆           ┆   ┆ f64        ┆          ┆ f64        ┆          │
╞═════════╪═══════════╪════════════╪═══════════╪═══╪════════════╪══════════╪════════════╪══════════╡
│ TRA     ┆ 1524      ┆ 0.019837   ┆ 29        ┆ … ┆ 0.586066   ┆ 0.478703 ┆ 0.622732   ┆ 0.549869 │
│ TRA_TRB ┆ 1524      ┆ 0.043945   ┆ 52        ┆ … ┆ 0.809816   ┆ 0.366964 ┆ 0.425415   ┆ 0.37664  │
│ TRB     ┆ 1524      ┆ 0.020975   ┆ 26        ┆ … ┆ 0.703349   ┆ 0.500222 ┆ 0.599816   ┆ 0.528215 │
└─────────┴───────────┴────────────┴───────────┴───┴────────────┴──────────┴────────────┴──────────┘

Top epitope classification rows by F1 per mode:
shape: (24, 6)
┌──────┬────────────┬───────────┬──────────┬───────────┬─────────┐
│ mode ┆ epitope    ┆ precision ┆ recall   ┆ f1        ┆ support │
│ ---  ┆ ---        ┆ ---       ┆ ---      ┆ ---       ┆ ---     │
│ str  ┆ str        ┆ f64       ┆ f64      ┆ f64       ┆ i64     │
╞══════╪════════════╪═══════════╪══════════╪═══════════╪═════════╡
│ TRA  ┆ GILGFVFTL  ┆ 0.73444   ┆ 0.708    ┆ 0.720978  ┆ 250     │
│ TRA  ┆ IVTDFSVIK  ┆ 0.850575  ┆ 0.592    ┆ 0.698113  ┆ 250     │
│ TRA  ┆ RAKFKQLL   ┆ 0.739535  ┆ 0.636    ┆ 0.683871  ┆ 250     │
│ TRA  ┆ LLDFVRFMGV ┆ 0.75      ┆ 0.555556 ┆ 0.638298  ┆ 27      │
│ TRA  ┆ ELAGIGILTV ┆ 0.633803  ┆ 0.54     ┆ 0.583153  ┆ 250     │
│ …    ┆ …          ┆ …         ┆ …        ┆ …         ┆ …       │
│ TRB  ┆ RAKFKQLL   ┆ 0.603053  ┆ 0.632    ┆ 0.6171875 ┆ 250     │
│ TRB  ┆ AVFDRKSDAK ┆ 0.728477  ┆ 0.44     ┆ 0.548628  ┆ 250     │
│ TRB  ┆ MLDLQPETT  ┆ 0.833333  ┆ 0.384615 ┆ 0.526316  ┆ 13      │
│ TRB  ┆ ELAGIGILTV ┆ 0.869159  ┆ 0.372    ┆ 0.521008  ┆ 250     │
│ TRB  ┆ RLRAEAQVK  ┆ 0.939394  ┆ 0.295238 ┆ 0.449275  ┆ 105     │
└──────┴────────────┴───────────┴──────────┴───────────┴─────────┘
  mode n_records embed_time_s pcs_90pct eps clusters retention purity consistency_70 macro_f1 weighted_f1 accuracy
0 TRA 1524 0.020 29 0.331 244 0.741 0.789 0.586 0.479 0.623 0.550
1 TRA_TRB 1524 0.044 52 0.655 163 0.732 0.904 0.810 0.367 0.425 0.377
2 TRB 1524 0.021 26 0.391 209 0.719 0.851 0.703 0.500 0.600 0.528
[6]:
# Plot cumulative PCA variance and k-nearest-neighbor distance curves for eps selection.
fig, axes = plt.subplots(2, 3, figsize=(18, 9.5), constrained_layout=True)

for col, mode in enumerate(['TRA', 'TRB', 'TRA_TRB']):
    analysis = mode_results[mode]['analysis']

    max_show = min(220, len(analysis['cum']))
    ax = axes[0, col]
    ax.plot(np.arange(1, max_show + 1), analysis['cum'][:max_show] * 100, lw=2.3, color='#006d77')
    ax.axhline(90, color='#222222', ls='--', lw=1.2)
    ax.axvline(analysis['n_comp'], color='#222222', ls=':', lw=1.2)
    ax.set_title(f'{mode} | Cumulative PCA Variance', fontsize=13, pad=10)
    ax.set_xlabel('Principal component index', fontsize=11)
    ax.set_ylabel('Cumulative variance (%)', fontsize=11)
    ax.set_ylim(0, 101)

    kth = analysis['kth']
    ax = axes[1, col]
    ax.plot(kth, lw=1.8, color='#1d4e89')
    if analysis['knee_idx'] is not None:
        ax.axvline(int(analysis['knee_idx']), color='#222222', ls=':', lw=1.1, label='knee idx')
    ax.axhline(float(analysis['eps']), color='#b00020', ls='--', lw=1.1, label='eps')
    ax.set_title(f'{mode} | 4-NN Distance Profile', fontsize=13, pad=10)
    ax.set_xlabel('Sorted clonotype index', fontsize=11)
    ax.set_ylabel('4-nearest-neighbor distance', fontsize=11)
    ax.legend(loc='upper left', frameon=False, fontsize=9)

fig.suptitle('Donor1 Curated Embedding Diagnostics', fontsize=16, y=1.02)
plt.show()
../_images/notebooks_tcremp_10xdcode_analysis_6_0.png
[7]:
# Plot PCA/UMAP projections colored by epitope and by cluster assignments.
def _make_plot_frame(coords, labels, clusters):
    return pl.DataFrame({
        'x': coords[:, 0],
        'y': coords[:, 1],
        'epitope': labels,
        'cluster': clusters.astype(int),
    }).with_columns(
        pl.when(pl.col('cluster') == -1)
        .then(pl.lit('noise'))
        .otherwise(pl.format('cluster_{}', pl.col('cluster')))
        .alias('cluster_label')
    )


def _plot_by_epitope(ax, frame, title):
    top_labels = (
        frame.group_by('epitope').len().sort('len', descending=True).head(10).get_column('epitope').to_list()
    )
    display_frame = frame.with_columns(
        pl.when(pl.col('epitope').is_in(top_labels)).then(pl.col('epitope')).otherwise(pl.lit('other')).alias('epitope_disp')
    )
    sns.scatterplot(
        data=display_frame.to_pandas(),
        x='x',
        y='y',
        hue='epitope_disp',
        palette='tab20',
        s=16,
        alpha=0.76,
        linewidth=0,
        ax=ax,
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Dim 1', fontsize=10)
    ax.set_ylabel('Dim 2', fontsize=10)


def _plot_by_cluster(ax, frame, title):
    top_clusters = (
        frame.filter(pl.col('cluster') != -1)
        .group_by('cluster_label')
        .len()
        .sort('len', descending=True)
        .head(10)
        .get_column('cluster_label')
        .to_list()
    )
    display_frame = frame.with_columns(
        pl.when((pl.col('cluster_label').is_in(top_clusters)) | (pl.col('cluster_label') == 'noise'))
        .then(pl.col('cluster_label'))
        .otherwise(pl.lit('other_cluster'))
        .alias('cluster_disp')
    )
    sns.scatterplot(
        data=display_frame.to_pandas(),
        x='x',
        y='y',
        hue='cluster_disp',
        palette='tab20b',
        s=15,
        alpha=0.78,
        linewidth=0,
        ax=ax,
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Dim 1', fontsize=10)
    ax.set_ylabel('Dim 2', fontsize=10)


fig, axes = plt.subplots(3, 3, figsize=(18, 13), constrained_layout=True)
for col, mode in enumerate(['TRA', 'TRB', 'TRA_TRB']):
    pca_coords = mode_results[mode]['analysis']['X_pca'][:, :2]
    umap_coords = mode_results[mode]['X_umap']
    clusters = mode_results[mode]['analysis']['clusters']

    pca_frame = _make_plot_frame(pca_coords, balanced_labels, clusters)
    umap_frame = _make_plot_frame(umap_coords, balanced_labels, clusters)

    _plot_by_epitope(axes[0, col], pca_frame, f'{mode} | PCA by epitope')
    _plot_by_epitope(axes[1, col], umap_frame, f'{mode} | UMAP by epitope')
    _plot_by_cluster(axes[2, col], umap_frame, f'{mode} | UMAP by cluster')

for r in range(3):
    for c in range(3):
        if axes[r, c].legend_ is not None:
            axes[r, c].legend_.remove()

handles_e, labels_e = axes[0, 2].get_legend_handles_labels()
handles_c, labels_c = axes[2, 2].get_legend_handles_labels()
fig.legend(handles_e, labels_e, title='Epitope', loc='center left', bbox_to_anchor=(1.01, 0.67), frameon=False, fontsize=9)
fig.legend(handles_c, labels_c, title='Cluster', loc='center left', bbox_to_anchor=(1.01, 0.30), frameon=False, fontsize=9)
fig.suptitle('Donor1 Curated Embeddings: Epitope and Cluster Colorings', fontsize=16, y=1.02)
plt.show()
../_images/notebooks_tcremp_10xdcode_analysis_7_0.png
[8]:
# Display epitope-level classification score tables by mode with pandas styling.
for mode in ['TRA', 'TRB', 'TRA_TRB']:
    print(f'\n{mode} epitope classification scores (top 15 by support):')
    display_df = (
        epitope_scores
        .filter(pl.col('mode') == mode)
        .sort(['support', 'f1'], descending=[True, True])
        .head(15)
    )
    pd_df = display_df.to_pandas()
    display(
        pd_df.style
        .background_gradient(subset=['precision', 'recall', 'f1'], cmap='YlOrRd')
        .format({'precision': '{:.3f}', 'recall': '{:.3f}', 'f1': '{:.3f}'})
    )

print('\nGlobal mode summary with clustering and classification metrics:')
summary_pd = mode_summary.sort('weighted_f1', descending=True).to_pandas()
display(
    summary_pd.style
    .background_gradient(subset=['purity', 'consistency_70', 'weighted_f1', 'accuracy'], cmap='YlGnBu')
    .format(precision=3)
)

TRA epitope classification scores (top 15 by support):
  mode epitope precision recall f1 support
0 TRA GILGFVFTL 0.734 0.708 0.721 250
1 TRA IVTDFSVIK 0.851 0.592 0.698 250
2 TRA RAKFKQLL 0.740 0.636 0.684 250
3 TRA ELAGIGILTV 0.634 0.540 0.583 250
4 TRA AVFDRKSDAK 0.729 0.484 0.582 250
5 TRA RLRAEAQVK 0.833 0.381 0.523 105
6 TRA other 0.929 0.299 0.452 87
7 TRA GLCTLVAML 0.652 0.484 0.556 31
8 TRA LLDFVRFMGV 0.750 0.556 0.638 27
9 TRA MLDLQPETT 0.000 0.000 0.000 13
10 TRA CYTWNQMNL 1.000 0.182 0.308 11

TRB epitope classification scores (top 15 by support):
  mode epitope precision recall f1 support
0 TRB IVTDFSVIK 0.911 0.652 0.760 250
1 TRB GILGFVFTL 0.643 0.772 0.702 250
2 TRB RAKFKQLL 0.603 0.632 0.617 250
3 TRB AVFDRKSDAK 0.728 0.440 0.549 250
4 TRB ELAGIGILTV 0.869 0.372 0.521 250
5 TRB RLRAEAQVK 0.939 0.295 0.449 105
6 TRB other 0.957 0.253 0.400 87
7 TRB GLCTLVAML 0.955 0.677 0.792 31
8 TRB LLDFVRFMGV 0.700 0.259 0.378 27
9 TRB MLDLQPETT 0.833 0.385 0.526 13
10 TRB CYTWNQMNL 1.000 0.182 0.308 11

TRA_TRB epitope classification scores (top 15 by support):
  mode epitope precision recall f1 support
0 TRA_TRB AVFDRKSDAK 0.809 0.372 0.510 250
1 TRA_TRB RAKFKQLL 0.914 0.340 0.496 250
2 TRA_TRB IVTDFSVIK 0.305 0.844 0.448 250
3 TRA_TRB ELAGIGILTV 0.917 0.264 0.410 250
4 TRA_TRB GILGFVFTL 0.760 0.152 0.253 250
5 TRA_TRB RLRAEAQVK 0.917 0.314 0.468 105
6 TRA_TRB other 0.893 0.287 0.435 87
7 TRA_TRB GLCTLVAML 0.615 0.258 0.364 31
8 TRA_TRB LLDFVRFMGV 0.846 0.407 0.550 27
9 TRA_TRB MLDLQPETT 1.000 0.308 0.471 13
10 TRA_TRB CYTWNQMNL 0.000 0.000 0.000 11

Global mode summary with clustering and classification metrics:
  mode n_records embed_time_s pcs_90pct eps clusters retention purity consistency_70 macro_f1 weighted_f1 accuracy
0 TRA 1524 0.020 29 0.331 244 0.741 0.789 0.586 0.479 0.623 0.550
1 TRB 1524 0.021 26 0.391 209 0.719 0.851 0.703 0.500 0.600 0.528
2 TRA_TRB 1524 0.044 52 0.655 163 0.732 0.904 0.810 0.367 0.425 0.377