Cross-Cohort Signature Correspondence Analysis¶
This notebook quantifies the reproducibility of disease signatures across biobanks (UK Biobank, Mass General Brigham, All of Us) using:
- Cross-tabulation matrices: Normalized confusion matrices showing cluster correspondence
- Modified Jaccard similarity: Set-based metric for quantifying signature overlap
Key metric: Median modified Jaccard similarity = 0.800 (80.0%) across all UKB signatures
Setup and Data Loading¶
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from pathlib import Path
import warnings
# Silence library warnings so the captured cell output stays readable.
warnings.filterwarnings('ignore')

# Plot styling shared by every figure produced below.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.dpi': 300,
    'font.size': 10,
})

# Figures and tables are written to the current working directory.
OUTPUT_DIR = Path('.')
OUTPUT_DIR.mkdir(exist_ok=True)

print("Setup complete")
Setup complete
# Load model checkpoints
print("="*80)
print("LOADING MODEL CHECKPOINTS")
print("="*80)

# One checkpoint per biobank. (A stray earlier assignment that pointed
# mgb_checkpoint_path at the UKB file was removed; it was overwritten before use.)
ukb_checkpoint_path = '/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_retrospective_full/enrollment_model_W0.0001_batch_0_10000.pt'
mgb_checkpoint_path = '/Users/sarahurbut/Dropbox-Personal/model_with_kappa_bigam_MGB.pt'
aou_checkpoint_path = '/Users/sarahurbut/Dropbox-Personal/model_with_kappa_bigam_AOU.pt'


def _names_to_list(names):
    """Return `names` as a plain list when it exposes `.values` (e.g. a pandas Series)."""
    return names.values.tolist() if hasattr(names, 'values') else names


def _clusters_to_numpy(clusters):
    """Return `clusters` as a numpy array when it is a torch tensor; otherwise unchanged."""
    if isinstance(clusters, torch.Tensor):
        return clusters.detach().cpu().numpy()
    return clusters


# weights_only=False matches the other loaders in this notebook: the checkpoints
# carry non-tensor entries (e.g. 'disease_names'), which the weights-only
# unpickler (the default from PyTorch 2.6) refuses to load.
# NOTE(review): this unpickles arbitrary objects - only load trusted checkpoint files.
ukb_checkpoint = torch.load(ukb_checkpoint_path, map_location='cpu', weights_only=False)
mgb_checkpoint = torch.load(mgb_checkpoint_path, map_location='cpu', weights_only=False)
aou_checkpoint = torch.load(aou_checkpoint_path, map_location='cpu', weights_only=False)
print(f"\n✓ UKB checkpoint loaded")
print(f"✓ MGB checkpoint loaded")
print(f"✓ AoU checkpoint loaded")

# Extract disease names and clusters.
# UKB names and initial clusters come from separate files; MGB and AoU carry
# both inside their checkpoints.
ukb_diseases = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/disease_names.csv')['x']
mgb_diseases = mgb_checkpoint['disease_names']
aou_diseases = aou_checkpoint['disease_names']
ukb_clusters = torch.load(
    '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/initial_clusters_400k.pt',
    map_location='cpu',
    weights_only=False,
)
mgb_clusters = mgb_checkpoint['clusters']
aou_clusters = aou_checkpoint['clusters']

# Convert to lists if needed
ukb_diseases = _names_to_list(ukb_diseases)
mgb_diseases = _names_to_list(mgb_diseases)
aou_diseases = _names_to_list(aou_diseases)

# Convert clusters to numpy arrays if needed
ukb_clusters = _clusters_to_numpy(ukb_clusters)
mgb_clusters = _clusters_to_numpy(mgb_clusters)
aou_clusters = _clusters_to_numpy(aou_clusters)

# The signature count is reported as max label + 1 (labels assumed to start at 0).
print(f"\nUKB: {len(ukb_diseases)} diseases, {ukb_clusters.max()+1} signatures")
print(f"MGB: {len(mgb_diseases)} diseases, {mgb_clusters.max()+1} signatures")
print(f"AoU: {len(aou_diseases)} diseases, {aou_clusters.max()+1} signatures")
================================================================================ LOADING MODEL CHECKPOINTS ================================================================================ ✓ UKB checkpoint loaded ✓ MGB checkpoint loaded ✓ AoU checkpoint loaded UKB: 348 diseases, 20 signatures MGB: 346 diseases, 20 signatures AoU: 348 diseases, 20 signatures
Create DataFrames and Find Common Diseases¶
# One two-column table (disease name, cluster label) per biobank.
ukb_df = pd.DataFrame({'Disease': ukb_diseases, 'UKB_cluster': ukb_clusters})
mgb_df = pd.DataFrame({'Disease': mgb_diseases, 'MGB_cluster': mgb_clusters})
aou_df = pd.DataFrame({'Disease': aou_diseases, 'AoU_cluster': aou_clusters})

# Disease names shared between cohorts. The resulting list order follows set
# iteration order, so it carries no meaning.
_ukb_names = set(ukb_diseases)
_mgb_names = set(mgb_diseases)
_aou_names = set(aou_diseases)
common_ukb_mgb = list(_ukb_names & _mgb_names)
common_ukb_aou = list(_ukb_names & _aou_names)
common_all_three = list(_ukb_names & _mgb_names & _aou_names)

for _banner_line in ("="*80, "COMMON DISEASES", "="*80):
    print(_banner_line)
print(f"\nUKB ↔ MGB: {len(common_ukb_mgb)} common diseases")
print(f"UKB ↔ AoU: {len(common_ukb_aou)} common diseases")
print(f"All three biobanks: {len(common_all_three)} common diseases")
================================================================================ COMMON DISEASES ================================================================================ UKB ↔ MGB: 346 common diseases UKB ↔ AoU: 348 common diseases All three biobanks: 346 common diseases
Cross-Tabulation Matrices¶
# Row-normalized cross-tabulations of UKB clusters against each other cohort.
for _banner_line in ("="*80, "CROSS-TABULATION MATRICES", "="*80):
    print(_banner_line)


def _align_clusters(shared_diseases, other_df):
    """Return one row per shared disease with its UKB and other-cohort cluster labels."""
    aligned = pd.DataFrame({'Disease': shared_diseases})
    for table in (ukb_df, other_df):
        aligned = aligned.merge(table, on='Disease', how='left')
    return aligned


def _best_match_table(cross_tab, other_label):
    """For each UKB row, record the column with the largest share and that share."""
    table = pd.DataFrame({
        'UKB': cross_tab.index,
        other_label: cross_tab.idxmax(axis=1),
        'Proportion': cross_tab.max(axis=1).values,
    })
    # Sorted by the matched cluster so the heatmap rows can follow that order.
    return table.sort_values(other_label)


# Cluster labels side by side for the diseases each pair of cohorts shares.
df_ukb_mgb = _align_clusters(common_ukb_mgb, mgb_df)
df_ukb_aou = _align_clusters(common_ukb_aou, aou_df)

# normalize='index' makes every UKB row sum to 1.
cross_tab_mgb = pd.crosstab(
    df_ukb_mgb['UKB_cluster'], df_ukb_mgb['MGB_cluster'], normalize='index'
)
cross_tab_aou = pd.crosstab(
    df_ukb_aou['UKB_cluster'], df_ukb_aou['AoU_cluster'], normalize='index'
)
print(f"\n✓ Cross-tabulation matrix (UKB vs MGB): {cross_tab_mgb.shape}")
print(f"✓ Cross-tabulation matrix (UKB vs AoU): {cross_tab_aou.shape}")

# Best matches, used only to order the heatmap rows.
best_matches_mgb = _best_match_table(cross_tab_mgb, 'MGB')
best_matches_aou = _best_match_table(cross_tab_aou, 'AoU')
print(f"\n✓ Best matches identified for visualization")
================================================================================ CROSS-TABULATION MATRICES ================================================================================ ✓ Cross-tabulation matrix (UKB vs MGB): (20, 20) ✓ Cross-tabulation matrix (UKB vs AoU): (20, 20) ✓ Best matches identified for visualization
Calculate Modified Jaccard Similarity¶
The modified Jaccard similarity for each UKB cluster $k$ is defined as:
$$J_k = \frac{|D_{k,\text{UKB}} \cap D_{k^*,\text{other}}|}{|D_{k,\text{UKB}}|}$$
where:
- $D_{k,\text{UKB}}$ is the set of diseases in UKB cluster $k$
- $D_{k^*,\text{other}}$ is the set of diseases in the best-matching cluster $k^*$ in the comparison cohort
- $|\cdot|$ denotes set cardinality
The overall cross-cohort similarity is the median of these cluster-specific similarities.
def calculate_modified_jaccard(ukb_df, other_df, other_col_name, common_diseases, biobank_name):
    """
    Calculate modified Jaccard similarity using set intersections.

    For each UKB cluster k:
    1. Get set of diseases in UKB cluster k (from common diseases)
    2. For each cluster in other biobank, calculate intersection
    3. Find best-matching cluster (maximum intersection / UKB cluster size);
       ties go to the lowest cluster label because only a strictly larger
       score replaces the current best
    4. Store Jaccard score for cluster k and print a one-line report

    Parameters:
    -----------
    ukb_df : DataFrame
        Must contain the columns 'Disease' and 'UKB_cluster'
    other_df : DataFrame
        Must contain the columns 'Disease' and `other_col_name`
    other_col_name : str
        Name of the cluster column in `other_df`
    common_diseases : iterable
        Disease names to keep from both tables before matching
    biobank_name : str
        Label for the comparison cohort, used only in the printed report

    Returns:
    --------
    jaccard_scores : list
        List of Jaccard scores for each UKB cluster (ascending cluster label)
    cluster_details : dict
        Dictionary keyed by UKB cluster label with 'jaccard', 'best_match',
        'ukb_size' and 'intersection_size' for each cluster
    """
    # Keep only diseases present in both cohorts, one row per shared disease.
    merged = ukb_df[ukb_df['Disease'].isin(common_diseases)].merge(
        other_df[other_df['Disease'].isin(common_diseases)],
        on='Disease', how='inner'
    )

    # The comparison cohort's cluster sets do not depend on the UKB cluster
    # being scored, so build them once instead of once per UKB cluster.
    other_cluster_sets = [
        (label, set(merged.loc[merged[other_col_name] == label, 'Disease']))
        for label in sorted(merged[other_col_name].unique())
    ]

    jaccard_scores = []
    cluster_details = {}
    for ukb_cluster in sorted(merged['UKB_cluster'].unique()):
        ukb_members = set(merged.loc[merged['UKB_cluster'] == ukb_cluster, 'Disease'])
        if not ukb_members:
            continue

        # Find best matching cluster in other biobank
        best_match_score = 0
        best_match_cluster = None
        best_intersection = set()
        for other_cluster, other_members in other_cluster_sets:
            intersection = ukb_members & other_members
            jaccard_k = len(intersection) / len(ukb_members)
            if jaccard_k > best_match_score:
                best_match_score = jaccard_k
                best_match_cluster = other_cluster
                best_intersection = intersection

        jaccard_scores.append(best_match_score)
        cluster_details[ukb_cluster] = {
            'jaccard': best_match_score,
            'best_match': best_match_cluster,
            'ukb_size': len(ukb_members),
            'intersection_size': len(best_intersection)
        }

        # best_match_cluster stays None when no cluster overlaps; formatting
        # None with ':2d' would raise, so fall back to a placeholder label.
        match_label = "NA" if best_match_cluster is None else f"{best_match_cluster:2d}"
        print(f" UKB cluster {ukb_cluster:2d} → {biobank_name} cluster {match_label}: "
              f"J = {best_match_score:.3f} ({len(best_intersection)}/{len(ukb_members)} diseases)")
    return jaccard_scores, cluster_details
# Score every UKB cluster against MGB and AoU, then summarize the scores.
for _banner_line in ("="*80, "MODIFIED JACCARD SIMILARITY CALCULATION", "="*80):
    print(_banner_line)

_rule = "-"*80

print("\n" + _rule)
print("UKB ↔ MGB:")
print(_rule)
jaccard_mgb, details_mgb = calculate_modified_jaccard(
    ukb_df, mgb_df,
    other_col_name='MGB_cluster',
    common_diseases=common_ukb_mgb,
    biobank_name='MGB',
)

print("\n" + _rule)
print("UKB ↔ AoU:")
print(_rule)
jaccard_aou, details_aou = calculate_modified_jaccard(
    ukb_df, aou_df,
    other_col_name='AoU_cluster',
    common_diseases=common_ukb_aou,
    biobank_name='AoU',
)

# Per-cohort medians, the median over both score lists concatenated, and the
# plain average of the two per-cohort medians.
median_mgb, median_aou = np.median(jaccard_mgb), np.median(jaccard_aou)
combined_median = np.median(jaccard_mgb + jaccard_aou)
average_medians = (median_mgb + median_aou) / 2

# Quartiles for the interquartile-range lines of the report.
_q1_mgb, _q3_mgb = np.percentile(jaccard_mgb, 25), np.percentile(jaccard_mgb, 75)
_q1_aou, _q3_aou = np.percentile(jaccard_aou, 25), np.percentile(jaccard_aou, 75)

print("\n" + "="*80)
print("SUMMARY STATISTICS")
print("="*80)
print(f"\nMedian Jaccard (UKB ↔ MGB): {median_mgb:.3f} ({median_mgb*100:.1f}%)")
print(f"Median Jaccard (UKB ↔ AoU): {median_aou:.3f} ({median_aou*100:.1f}%)")
print(f"\nCombined median Jaccard: {combined_median:.3f} ({combined_median*100:.1f}%)")
print(f"Average of medians: {average_medians:.3f} ({average_medians*100:.1f}%)")
print(f"\nIQR (UKB ↔ MGB): [{_q1_mgb:.3f}, {_q3_mgb:.3f}]")
print(f"IQR (UKB ↔ AoU): [{_q1_aou:.3f}, {_q3_aou:.3f}]")
================================================================================ MODIFIED JACCARD SIMILARITY CALCULATION ================================================================================ -------------------------------------------------------------------------------- UKB ↔ MGB: -------------------------------------------------------------------------------- UKB cluster 0 → MGB cluster 5: J = 0.929 (13/14 diseases) UKB cluster 1 → MGB cluster 2: J = 0.952 (20/21 diseases) UKB cluster 2 → MGB cluster 7: J = 0.933 (14/15 diseases) UKB cluster 3 → MGB cluster 12: J = 0.561 (46/82 diseases) UKB cluster 4 → MGB cluster 0: J = 1.000 (5/5 diseases) UKB cluster 5 → MGB cluster 5: J = 0.571 (4/7 diseases) UKB cluster 6 → MGB cluster 11: J = 0.875 (7/8 diseases) UKB cluster 7 → MGB cluster 1: J = 0.500 (11/22 diseases) UKB cluster 8 → MGB cluster 19: J = 0.679 (19/28 diseases) UKB cluster 9 → MGB cluster 9: J = 1.000 (12/12 diseases) UKB cluster 10 → MGB cluster 16: J = 1.000 (11/11 diseases) UKB cluster 11 → MGB cluster 10: J = 0.875 (7/8 diseases) UKB cluster 12 → MGB cluster 3: J = 0.714 (5/7 diseases) UKB cluster 13 → MGB cluster 18: J = 0.615 (8/13 diseases) UKB cluster 14 → MGB cluster 14: J = 0.900 (9/10 diseases) UKB cluster 15 → MGB cluster 6: J = 0.800 (4/5 diseases) UKB cluster 16 → MGB cluster 4: J = 0.690 (20/29 diseases) UKB cluster 17 → MGB cluster 8: J = 0.882 (15/17 diseases) UKB cluster 18 → MGB cluster 13: J = 0.778 (7/9 diseases) UKB cluster 19 → MGB cluster 15: J = 0.391 (9/23 diseases) -------------------------------------------------------------------------------- UKB ↔ AoU: -------------------------------------------------------------------------------- UKB cluster 0 → AoU cluster 16: J = 0.812 (13/16 diseases) UKB cluster 1 → AoU cluster 6: J = 0.667 (14/21 diseases) UKB cluster 2 → AoU cluster 12: J = 1.000 (15/15 diseases) UKB cluster 3 → AoU cluster 5: J = 0.451 (37/82 diseases) UKB cluster 4 → AoU cluster 3: J = 1.000 (5/5 diseases) 
UKB cluster 5 → AoU cluster 16: J = 0.714 (5/7 diseases) UKB cluster 6 → AoU cluster 11: J = 1.000 (8/8 diseases) UKB cluster 7 → AoU cluster 1: J = 0.182 (4/22 diseases) UKB cluster 8 → AoU cluster 7: J = 0.750 (21/28 diseases) UKB cluster 9 → AoU cluster 9: J = 1.000 (12/12 diseases) UKB cluster 10 → AoU cluster 8: J = 1.000 (11/11 diseases) UKB cluster 11 → AoU cluster 4: J = 1.000 (8/8 diseases) UKB cluster 12 → AoU cluster 15: J = 1.000 (7/7 diseases) UKB cluster 13 → AoU cluster 5: J = 0.538 (7/13 diseases) UKB cluster 14 → AoU cluster 19: J = 0.900 (9/10 diseases) UKB cluster 15 → AoU cluster 14: J = 0.800 (4/5 diseases) UKB cluster 16 → AoU cluster 1: J = 0.690 (20/29 diseases) UKB cluster 17 → AoU cluster 12: J = 0.765 (13/17 diseases) UKB cluster 18 → AoU cluster 10: J = 0.667 (6/9 diseases) UKB cluster 19 → AoU cluster 13: J = 0.391 (9/23 diseases) ================================================================================ SUMMARY STATISTICS ================================================================================ Median Jaccard (UKB ↔ MGB): 0.838 (83.8%) Median Jaccard (UKB ↔ AoU): 0.782 (78.2%) Combined median Jaccard: 0.800 (80.0%) Average of medians: 0.810 (81.0%) IQR (UKB ↔ MGB): [0.663, 0.930] IQR (UKB ↔ AoU): [0.667, 1.000]
# Side-by-side heatmaps of the row-normalized cross-tabulations.
def _draw_correspondence_panel(ax, cross_tab, ukb_row_order, cohort):
    """Draw one cross-tabulation heatmap with its rows in `ukb_row_order`."""
    sns.heatmap(
        cross_tab.loc[ukb_row_order],
        cmap='Reds',
        vmin=0, vmax=1,
        ax=ax,
        cbar_kws={'label': 'Proportion'},
        linewidths=0.5,
        linecolor='gray'
    )
    ax.set_title(f'Cluster Correspondence: UKB vs {cohort}\n(Common Diseases Only)',
                 fontsize=12, fontweight='bold', pad=15)
    ax.set_xlabel(f'{cohort} Cluster', fontsize=11, fontweight='bold')
    ax.set_ylabel('UKB Cluster', fontsize=11, fontweight='bold')


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
_draw_correspondence_panel(ax1, cross_tab_mgb, best_matches_mgb['UKB'], 'MGB')
_draw_correspondence_panel(ax2, cross_tab_aou, best_matches_aou['UKB'], 'AoU')
plt.tight_layout()

# Write the figure in both vector and raster form.
_heatmap_pdf = OUTPUT_DIR / 'cross_cohort_correspondence_heatmaps.pdf'
_heatmap_png = OUTPUT_DIR / 'cross_cohort_correspondence_heatmaps.png'
for _target in (_heatmap_pdf, _heatmap_png):
    plt.savefig(_target, dpi=300, bbox_inches='tight')
print(f"✓ Saved heatmaps to: {_heatmap_pdf}")
plt.show()
✓ Saved heatmaps to: cross_cohort_correspondence_heatmaps.pdf
2. Jaccard Similarity by Cluster¶
# Per-cluster bar charts of the modified Jaccard scores, one panel per cohort.
def _draw_score_bars(ax, cluster_ids, scores, median, cohort, color):
    """Draw labelled score bars plus a dashed median line on `ax`; return the bars."""
    bars = ax.bar(cluster_ids, scores,
                  color=color, alpha=0.7, edgecolor='black', linewidth=1)
    ax.axhline(y=median, color='red', linestyle='--', linewidth=2,
               label=f'Median = {median:.3f}')
    ax.set_xlabel('UKB Cluster', fontsize=11, fontweight='bold')
    ax.set_ylabel('Composition Preservation Probability', fontsize=11, fontweight='bold')
    ax.set_title(f'Composition Preservation Probability: UKB ↔ {cohort}\n(by UKB Cluster)',
                 fontsize=12, fontweight='bold', pad=15)
    ax.set_ylim([0, 1.1])
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3, axis='y')
    # Print each score just above its bar.
    for bar, score in zip(bars, scores):
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
                f'{score:.2f}', ha='center', va='bottom', fontsize=8)
    return bars


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Scores in ascending UKB cluster order.
ukb_clusters_mgb = sorted(details_mgb)
jaccard_values_mgb = [details_mgb[k]['jaccard'] for k in ukb_clusters_mgb]
ukb_clusters_aou = sorted(details_aou)
jaccard_values_aou = [details_aou[k]['jaccard'] for k in ukb_clusters_aou]

bars1 = _draw_score_bars(ax1, ukb_clusters_mgb, jaccard_values_mgb,
                         median_mgb, 'MGB', 'steelblue')
bars2 = _draw_score_bars(ax2, ukb_clusters_aou, jaccard_values_aou,
                         median_aou, 'AoU', 'darkgreen')
plt.tight_layout()

_bars_pdf = OUTPUT_DIR / 'jaccard_similarity_by_cluster.pdf'
_bars_png = OUTPUT_DIR / 'jaccard_similarity_by_cluster.png'
for _target in (_bars_pdf, _bars_png):
    plt.savefig(_target, dpi=300, bbox_inches='tight')
print(f"✓ Saved bar plots to: {_bars_pdf}")
plt.show()
✓ Saved bar plots to: jaccard_similarity_by_cluster.pdf
3. Distribution of Jaccard Similarities¶
# Overlaid histograms of the per-cluster scores for both comparisons.
fig, ax = plt.subplots(figsize=(10, 6))

# Both score lists concatenated; reused by the summary table.
all_jaccard = jaccard_mgb + jaccard_aou

_histogram_series = (
    ('MGB', jaccard_mgb, 'steelblue'),
    ('AoU', jaccard_aou, 'darkgreen'),
)
_median_series = (
    ('MGB', median_mgb, 'steelblue'),
    ('AoU', median_aou, 'darkgreen'),
)

# Histograms only (no density curve is drawn).
for _cohort, _scores, _color in _histogram_series:
    ax.hist(_scores, bins=15, alpha=0.6, label=f'UKB ↔ {_cohort} (n={len(_scores)})',
            color=_color, edgecolor='black', linewidth=1)

# Dashed per-cohort medians, then the solid combined median.
for _cohort, _median, _color in _median_series:
    ax.axvline(x=_median, color=_color, linestyle='--', linewidth=2,
               label=f'{_cohort} median = {_median:.3f}')
ax.axvline(x=combined_median, color='red', linestyle='-', linewidth=2,
           label=f'Combined median = {combined_median:.3f}')

ax.set_xlabel('Modified Jaccard Similarity', fontsize=11, fontweight='bold')
ax.set_ylabel('Number of Clusters', fontsize=11, fontweight='bold')
ax.set_title('Distribution of Modified Jaccard Similarities\nAcross All UKB Clusters',
             fontsize=12, fontweight='bold', pad=15)
ax.legend(fontsize=10, loc='upper left')
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()

_distribution_pdf = OUTPUT_DIR / 'jaccard_similarity_distribution.pdf'
_distribution_png = OUTPUT_DIR / 'jaccard_similarity_distribution.png'
for _target in (_distribution_pdf, _distribution_png):
    plt.savefig(_target, dpi=300, bbox_inches='tight')
print(f"✓ Saved distribution plot to: {_distribution_pdf}")
plt.show()
✓ Saved distribution plot to: jaccard_similarity_distribution.pdf
# Summary table: one row per comparison plus a combined row.
def _iqr_label(scores):
    """Format the 25th and 75th percentiles of `scores` as '[q1, q3]'."""
    lower = np.percentile(scores, 25)
    upper = np.percentile(scores, 75)
    return f"[{lower:.3f}, {upper:.3f}]"


# (label, scores, median, shared diseases) for each table row.
_summary_rows = [
    ('UKB ↔ MGB', jaccard_mgb, median_mgb, common_ukb_mgb),
    ('UKB ↔ AoU', jaccard_aou, median_aou, common_ukb_aou),
    ('Combined', all_jaccard, combined_median, common_all_three),
]
summary_data = {
    'Comparison': [label for label, _s, _m, _c in _summary_rows],
    'Median Jaccard': [f'{median:.3f}' for _l, _s, median, _c in _summary_rows],
    'Median %': [f'{median*100:.1f}%' for _l, _s, median, _c in _summary_rows],
    'IQR': [_iqr_label(scores) for _l, scores, _m, _c in _summary_rows],
    'N Clusters': [len(scores) for _l, scores, _m, _c in _summary_rows],
    'Common Diseases': [len(shared) for _l, _s, _m, shared in _summary_rows],
}
summary_df = pd.DataFrame(summary_data)

for _banner_line in ("="*80, "SUMMARY TABLE", "="*80):
    print(_banner_line)
print("\n", summary_df.to_string(index=False))

# Save to CSV
_summary_csv = OUTPUT_DIR / 'cross_cohort_similarity_summary.csv'
summary_df.to_csv(_summary_csv, index=False)
print(f"\n✓ Saved summary table to: {_summary_csv}")
================================================================================ SUMMARY TABLE ================================================================================ Comparison Median Jaccard Median % IQR N Clusters Common Diseases UKB ↔ MGB 0.838 83.8% [0.663, 0.930] 20 346 UKB ↔ AoU 0.782 78.2% [0.667, 1.000] 20 348 Combined 0.800 80.0% [0.667, 0.964] 40 346 ✓ Saved summary table to: cross_cohort_similarity_summary.csv
Alternative Analysis: Psi-Based Similarity¶
Giovanni's concern: The initial cluster designation doesn't recognize the fit and treats all diseases equally (doesn't weight by cluster-specific affinity).
Solution: Use posterior psi (ψ) from fitted models, which captures the strength of disease-signature associations learned during training.
Approach:
- Load and average psi across batches for each cohort
- Compute similarity between cohorts based on psi matrices (using cosine similarity or correlation)
- Create heatmaps showing psi-based signature correspondence
- This weights diseases by their association strength with each signature, rather than treating all diseases equally
def calculate_weighted_modified_jaccard(ukb_psi, other_psi, ukb_clusters, other_clusters,
                                        ukb_disease_names, other_disease_names,
                                        common_diseases, biobank_name):
    """
    Calculate weighted modified Jaccard similarity using cluster assignments weighted by psi coefficients.

    For each UKB cluster k (restricted to diseases present in both cohorts):
    1. Weighted size of k = sum of ukb_psi[k, d] over its diseases d
    2. For each cluster k' in the other biobank, weighted overlap = sum of
       ukb_psi[k, d] + other_psi[k', d] over the diseases d in both k and k'
    3. Score = weighted overlap / (2 * weighted size of k); the factor 2
       accounts for the overlap adding a psi value from each cohort
    4. Keep the cluster with the largest score; ties go to the lowest label
       because only a strictly larger score replaces the current best

    Notes:
    - The score is not bounded by 1: the other cohort's psi values can exceed
      the UKB ones.
    - A UKB cluster whose weighted size is exactly 0 is skipped. A negative
      weighted size gives every candidate a score of 0.0, so the cluster is
      reported with score 0 and best_match None. NOTE(review): this suggests
      non-negative weights are expected - confirm against the caller.
    - `biobank_name` is accepted but not used by this function.

    Parameters:
    -----------
    ukb_psi, other_psi : 2-D array-like
        Weights, indexed as [cluster label, disease index]
    ukb_clusters, other_clusters : sequence
        Cluster label of each disease, indexed by disease index
    ukb_disease_names, other_disease_names : sequence
        Disease names; position in the sequence is the disease index
    common_diseases : iterable
        Disease names to consider; names missing from either cohort are ignored
    biobank_name : str
        Unused

    Returns:
    --------
    weighted_jaccard_scores : list
        List of weighted Jaccard scores for each UKB cluster that was scored
    cluster_details : dict
        Dictionary keyed by UKB cluster label with 'weighted_jaccard',
        'best_match', 'weighted_overlap' and 'ukb_weighted_size'
    """
    # Create disease name to index mapping
    ukb_disease_to_idx = {d: i for i, d in enumerate(ukb_disease_names)}
    other_disease_to_idx = {d: i for i, d in enumerate(other_disease_names)}

    # Common disease name -> (ukb_idx, other_idx)
    disease_mapping = {
        d: (ukb_disease_to_idx[d], other_disease_to_idx[d])
        for d in common_diseases
        if d in ukb_disease_to_idx and d in other_disease_to_idx
    }

    # The candidate clusters do not depend on the UKB cluster being scored.
    other_unique_clusters = sorted(set(other_clusters))

    weighted_jaccard_scores = []
    cluster_details = {}
    for ukb_sig in sorted(set(ukb_clusters)):
        # Index pairs of the common diseases assigned to this UKB signature
        members = [
            (ukb_idx, other_idx)
            for ukb_idx, other_idx in disease_mapping.values()
            if ukb_clusters[ukb_idx] == ukb_sig
        ]
        if not members:
            continue

        # Weighted size of UKB signature (sum of psi values)
        ukb_weighted_size = sum(ukb_psi[ukb_sig, ukb_idx] for ukb_idx, _ in members)
        if ukb_weighted_size == 0:
            continue

        # Find best matching cluster in other biobank
        best_match_score = 0
        best_match_cluster = None
        best_weighted_overlap = 0
        for other_sig in other_unique_clusters:
            # Diseases in both signatures contribute both psi values
            weighted_overlap = sum(
                ukb_psi[ukb_sig, ukb_idx] + other_psi[other_sig, other_idx]
                for ukb_idx, other_idx in members
                if other_clusters[other_idx] == other_sig
            )
            # Intersection over UKB size, as in the unweighted modified Jaccard
            if ukb_weighted_size > 0:
                score = weighted_overlap / (2 * ukb_weighted_size)
            else:
                score = 0.0
            if score > best_match_score:
                best_match_score = score
                best_match_cluster = other_sig
                best_weighted_overlap = weighted_overlap

        weighted_jaccard_scores.append(best_match_score)
        cluster_details[ukb_sig] = {
            'weighted_jaccard': best_match_score,
            'best_match': best_match_cluster,
            'weighted_overlap': best_weighted_overlap,
            'ukb_weighted_size': ukb_weighted_size
        }
    return weighted_jaccard_scores, cluster_details
# IPython magic: execute the external weighted-coefficient-overlap analysis script.
# The text that follows this line in the notebook is the output that script printed.
%run /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/reviewer_responses/notebooks/R3/giovanni_weighted_coefficient_overlap.py
================================================================================
GIOVANNI'S WEIGHTED COEFFICIENT OVERLAP ANALYSIS
================================================================================
1. Loading UKB psi (pooled from batches)...
Found 40 files matching pattern: enrollment_model_W0.0001_batch_*_*.pt
✓ Pooled 40 batches: shape (21, 348)
UKB Psi Summary:
Shape: (21, 348)
Min: -5.2067
Max: 1.8672
Mean: -2.1014
Median: -2.0473
Std: 0.8253
2. Loading AoU psi (pooled from batches)...
Found 25 files matching pattern: aou_model_batch_*.pt
✓ Pooled 25 batches: shape (21, 348)
AoU Psi Summary:
Shape: (21, 348)
Min: -5.0487
Max: 2.8246
Mean: -2.0600
Median: -2.0262
Std: 0.9735
3. Loading MGB psi (from checkpoint)...
MGB Psi Summary:
Shape: (21, 346)
Min: -5.1884
Max: 3.4834
Mean: -2.0386
Median: -2.0432
Std: 1.0410
4a. Loading UKB phi (pooled from batches, time-averaged)...
Found 40 files matching pattern: enrollment_model_W0.0001_batch_*_*.pt
✓ Pooled 40 batches: phi shape (21, 348, 52), time-averaged shape (21, 348)
4b. Loading AoU phi (pooled from batches, time-averaged)...
Found 25 files matching pattern: aou_model_batch_*.pt
✓ Pooled 25 batches: phi shape (21, 348, 52), time-averaged shape (21, 348)
4c. Loading MGB phi (from checkpoint, time-averaged)...
5. Loading disease names...
6. Loading initial clusters (old approach)...
✓ Loaded initial UKB clusters
✓ Loaded initial MGB and AoU clusters
7. Computing clusters from posterior psi (max signature per disease)...
✓ UKB: Computed posterior clusters from averaged psi
✓ MGB: Computed posterior clusters from psi
✓ AoU: Computed posterior clusters from averaged psi
UKB: 20 signatures
MGB: 20 signatures
AoU: 20 signatures
================================================================================
BATCH CONSISTENCY STATISTICS
================================================================================
UKB: Computing batch consistency...
✓ Computed batch consistency for 348 diseases
Median consistency: 100.0%
Mean consistency: 99.9%
Range: [92.5%, 100.0%]
Diseases with <100% consistency (6 total):
Tobacco use disorder: 92.5% (37/40 batches)
Type 2 diabetes: 95.0% (38/40 batches)
Hyposmolality and/or hyponatremia: 97.5% (39/40 batches)
Respiratory failure: 97.5% (39/40 batches)
Chronic renal failure [CKD]: 97.5% (39/40 batches)
Disorders of uterus, NEC: 97.5% (39/40 batches)
UKB batch consistency by signature:
Pct_Consistent
mean median min max count
Final_Sig
0 100.0 100.0 100.0 100.0 16
1 100.0 100.0 100.0 100.0 21
2 100.0 100.0 100.0 100.0 15
3 100.0 100.0 100.0 100.0 82
4 100.0 100.0 100.0 100.0 5
5 100.0 100.0 100.0 100.0 7
6 100.0 100.0 100.0 100.0 8
7 100.0 100.0 100.0 100.0 22
8 99.9 100.0 97.5 100.0 28
9 100.0 100.0 100.0 100.0 12
10 100.0 100.0 100.0 100.0 11
11 100.0 100.0 100.0 100.0 8
12 100.0 100.0 100.0 100.0 7
13 100.0 100.0 100.0 100.0 13
14 99.0 100.0 92.5 100.0 10
15 99.0 100.0 95.0 100.0 5
16 99.9 100.0 97.5 100.0 29
17 100.0 100.0 100.0 100.0 17
18 100.0 100.0 100.0 100.0 9
19 99.9 100.0 97.5 100.0 23
AoU: Computing batch consistency...
✓ Computed batch consistency for 348 diseases
Median consistency: 100.0%
Mean consistency: 99.9%
Range: [92.0%, 100.0%]
Diseases with <100% consistency (8 total):
Sleep disorders: 92.0% (23/25 batches)
Uterine leiomyoma: 96.0% (24/25 batches)
Other anemias: 96.0% (24/25 batches)
Inflammation of eyelids: 96.0% (24/25 batches)
Otitis media: 96.0% (24/25 batches)
Pain and other symptoms associated with female genital organs: 96.0% (24/25 batches)
Abdominal pain: 96.0% (24/25 batches)
Atopic/contact dermatitis due to other or unspecified: 96.0% (24/25 batches)
AoU batch consistency by signature:
Pct_Consistent
mean median min max count
Final_Sig
0 100.0 100.0 100.0 100.0 4
1 99.8 100.0 96.0 100.0 38
2 100.0 100.0 100.0 100.0 10
3 99.3 100.0 92.0 100.0 18
4 100.0 100.0 100.0 100.0 8
5 100.0 100.0 100.0 100.0 67
6 100.0 100.0 100.0 100.0 20
7 99.7 100.0 96.0 100.0 23
8 99.8 100.0 96.0 100.0 19
9 100.0 100.0 100.0 100.0 17
10 100.0 100.0 100.0 100.0 6
11 100.0 100.0 100.0 100.0 8
12 100.0 100.0 100.0 100.0 36
13 99.7 100.0 96.0 100.0 13
14 100.0 100.0 100.0 100.0 13
15 100.0 100.0 100.0 100.0 8
16 100.0 100.0 100.0 100.0 21
17 100.0 100.0 100.0 100.0 4
18 100.0 100.0 100.0 100.0 3
19 100.0 100.0 100.0 100.0 12
UKB: Comparing initial vs posterior cluster assignments...
✓ Compared 348 diseases
Changed signatures: 0 / 348 (0.0%)
Unchanged: 348 (100.0%)
AoU: Comparing initial vs posterior cluster assignments...
✓ Compared 348 diseases
Changed signatures: 0 / 348 (0.0%)
Unchanged: 348 (100.0%)
MGB: Comparing initial vs posterior cluster assignments...
✓ Compared 346 diseases
Changed signatures: 0 / 346 (0.0%)
Unchanged: 346 (100.0%)
MGB changes by initial signature:
N_Changed N_Total Pct_Changed
Initial_Sig
0 0 6 0.0
1 0 28 0.0
2 0 24 0.0
3 0 11 0.0
4 0 24 0.0
5 0 20 0.0
6 0 15 0.0
7 0 16 0.0
8 0 19 0.0
9 0 13 0.0
10 0 7 0.0
11 0 7 0.0
12 0 68 0.0
13 0 7 0.0
14 0 13 0.0
15 0 12 0.0
16 0 16 0.0
17 0 8 0.0
18 0 12 0.0
19 0 20 0.0
8. Finding common diseases...
Common UKB-MGB: 346 diseases
Common UKB-AoU: 348 diseases
================================================================================
BINARY JACCARD USING INITIAL CLUSTERS (OLD APPROACH)
================================================================================
UKB ↔ MGB:
Median Jaccard: 0.8375
Range: [0.3913, 1.0000]
UKB ↔ AoU:
Median Jaccard: 0.7824
Range: [0.1818, 1.0000]
================================================================================
EXPERIMENT 1: BINARY JACCARD USING POSTERIOR CLUSTERS
================================================================================
UKB ↔ MGB:
Median Jaccard: 0.8375
Range: [0.3913, 1.0000]
UKB ↔ AoU:
Median Jaccard: 0.7824
Range: [0.1818, 1.0000]
Plotting binary Jaccard heatmaps (posterior clusters)...
✓ Saved heatmap to: /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/reviewer_responses/notebooks/R3/binary_jaccard_posterior_heatmaps.pdf
Plotting binary Jaccard heatmaps (initial vs posterior)...
✓ Saved comparison heatmaps to: /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/reviewer_responses/notebooks/R3/binary_jaccard_comparison_heatmaps.pdf
================================================================================
EXPERIMENT 2: WEIGHTED OVERLAP USING NORMALIZED PSI (SIGMOID)
================================================================================
Formula:
Weighted Overlap(UKB_sig_k, Other_sig_k') =
Σ_{d ∈ overlap} sigmoid(ψ_{k,d}) / Σ_{d ∈ UKB_sig_k} sigmoid(ψ_{k,d})
where:
- overlap = diseases in both UKB signature k and Other signature k'
- UKB_sig_k = all diseases assigned to UKB signature k
- sigmoid(ψ_{k,d}) = expit(ψ_{k,d}) = 1 / (1 + exp(-ψ_{k,d}))
- This metric is bounded to [0, 1]
UKB ↔ MGB:
Similarity matrix shape: (21, 21)
Range: [0.0000, 1.0000]
Median best match: 0.8005
UKB ↔ AoU:
Similarity matrix shape: (21, 21)
Range: [0.0000, 1.0000]
Median best match: 0.7159
================================================================================
ORIGINAL: WEIGHTED OVERLAP USING RAW PSI
================================================================================
UKB ↔ MGB:
UKB ↔ MGB:
Similarity matrix shape: (21, 21)
Range: [-0.5310, 1.2985]
Median best match: 0.8005
UKB ↔ AoU:
Similarity matrix shape: (21, 21)
Range: [-1.1193, 2.1193]
Median best match: 0.6843
================================================================================
EXPERIMENT 3: WEIGHTED OVERLAP USING TIME-AVERAGED PHI (SIGMOID)
================================================================================
Formula:
Weighted Overlap(UKB_sig_k, Other_sig_k') =
Σ_{d ∈ overlap} sigmoid(φ̄_{k,d}) / Σ_{d ∈ UKB_sig_k} sigmoid(φ̄_{k,d})
where:
- overlap = diseases in both UKB signature k and Other signature k'
- UKB_sig_k = all diseases assigned to UKB signature k
- φ̄_{k,d} = mean_t(φ_{k,d,t}) (time-averaged phi)
- sigmoid(φ̄_{k,d}) = expit(φ̄_{k,d}) = 1 / (1 + exp(-φ̄_{k,d}))
- Cluster assignments use argmax(psi), but weights use phi
- This metric is bounded to [0, 1]
UKB ↔ MGB:
Similarity matrix shape: (21, 21)
Range: [0.0000, 1.0000]
Median best match: 0.7821
UKB ↔ AoU:
Similarity matrix shape: (21, 21)
Range: [0.0000, 1.0000]
Median best match: 0.7763
9a. Plotting heatmaps (normalized psi version)...
✓ Saved heatmap to: /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/reviewer_responses/notebooks/R3/giovanni_weighted_overlap_psi_heatmaps.pdf
9b. Plotting heatmaps (time-averaged phi version)... ✓ Saved heatmap to: /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/reviewer_responses/notebooks/R3/giovanni_weighted_overlap_phi_heatmaps.pdf
================================================================================ ANALYSIS COMPLETE ================================================================================
# Load and pool psi from batches for each cohort
from scipy.spatial.distance import cosine
from scipy.stats import pearsonr
import glob
def pool_psi_from_batches(batch_dir, pattern, max_batches=None):
    """
    Pool the psi parameter across a set of per-batch checkpoint files.

    Every file in ``batch_dir`` matching ``pattern`` is loaded in sorted
    (lexicographic) filename order. Psi is read from
    ``checkpoint['model_state_dict']['psi']`` when present, otherwise from
    ``checkpoint['psi']``. Files without psi, and files that raise while
    loading, are reported on stdout and skipped.

    Args:
        batch_dir: Directory containing batch files
        pattern: Glob pattern for batch files (e.g., "aou_model_batch_*.pt")
        max_batches: Maximum number of batches to load (None = all). The
            limit is applied to the sorted file list.

    Returns:
        Tuple ``(psi_mean, psi_std)``: element-wise mean and standard
        deviation of psi across the loaded batches, as numpy arrays with the
        shape of one batch's psi.

    Raises:
        ValueError: If no psi array could be loaded.
    """
    batch_dir = Path(batch_dir)

    # Locate the candidate checkpoints and apply the optional cap
    candidates = sorted(glob.glob(str(batch_dir / pattern)))
    print(f" Found {len(candidates)} files matching pattern: {pattern}")
    selected = candidates if max_batches is None else candidates[:max_batches]

    collected = []
    for file_path in selected:
        try:
            checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)

            # Prefer the copy nested in the state dict, fall back to top level
            nested = 'model_state_dict' in checkpoint and 'psi' in checkpoint['model_state_dict']
            if not nested and 'psi' not in checkpoint:
                print(f" Warning: No psi found in {Path(file_path).name}")
                continue
            psi = checkpoint['model_state_dict']['psi'] if nested else checkpoint['psi']

            # Tensors are detached and moved to numpy; anything else is kept as-is
            if torch.is_tensor(psi):
                psi = psi.detach().cpu().numpy()

            collected.append(psi)
            print(f" Loaded psi from {Path(file_path).name}, shape: {psi.shape}")
        except Exception as e:
            # Best effort: one unreadable batch must not abort the pooling
            print(f" Error loading {Path(file_path).name}: {e}")
            continue

    if not collected:
        raise ValueError(f"No psi arrays loaded from {batch_dir / pattern}")

    # Batches become the leading axis, then reduce over it
    psi_stack = np.stack(collected, axis=0)
    psi_mean = psi_stack.mean(axis=0)
    psi_std = psi_stack.std(axis=0)

    print(f" ✓ Pooled {len(collected)} batches")
    print(f" ✓ Psi shape: {psi_mean.shape}")
    print(f" ✓ Psi range: [{psi_mean.min():.4f}, {psi_mean.max():.4f}]")
    print(f" ✓ Psi mean: {psi_mean.mean():.4f}, std (across batches): {psi_std.mean():.4f}")
    return psi_mean, psi_std
def load_psi_from_checkpoint(checkpoint_path):
    """
    Read psi from one checkpoint file.

    Psi is taken from ``checkpoint['model_state_dict']['psi']`` when that
    entry exists, otherwise from ``checkpoint['psi']``. A tensor is detached
    and converted to a numpy array; other values are returned unchanged.

    Raises:
        ValueError: If neither location contains psi.
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    nested = 'model_state_dict' in checkpoint and 'psi' in checkpoint['model_state_dict']
    if not nested and 'psi' not in checkpoint:
        raise ValueError(f"No psi found in {checkpoint_path}")
    psi = checkpoint['model_state_dict']['psi'] if nested else checkpoint['psi']

    if torch.is_tensor(psi):
        psi = psi.detach().cpu().numpy()

    print(f" ✓ Loaded psi from {Path(checkpoint_path).name}, shape: {psi.shape}")
    return psi
print("="*80)
print("LOADING AND POOLING PSI FROM COHORTS")
print("="*80)

# Load UKB psi from batches
print("\n1. Loading UKB psi from batches...")
ukb_batch_dir = '/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_retrospective_full'
ukb_pattern = 'enrollment_model_W0.0001_batch_*_*.pt'
try:
    # Load all batches (max_batches=None loads all found files)
    ukb_psi, ukb_psi_std = pool_psi_from_batches(ukb_batch_dir, ukb_pattern, max_batches=None)
    print(f" ✓ UKB: Psi shape {ukb_psi.shape}, {ukb_psi.shape[0]} signatures, {ukb_psi.shape[1]} diseases")
except Exception as e:
    print(f" ✗ Error loading UKB psi: {e}")
    ukb_psi = None
    # Reset the companion std too, so it is never left undefined or stale
    ukb_psi_std = None

# Load AoU psi from batches
print("\n2. Loading AoU psi from batches...")
aou_batch_dir = '/Users/sarahurbut/Library/CloudStorage/Dropbox/aou_batches'
aou_pattern = 'aou_model_batch_*.pt'
try:
    aou_psi, aou_psi_std = pool_psi_from_batches(aou_batch_dir, aou_pattern, max_batches=None)
    print(f" ✓ AoU: Psi shape {aou_psi.shape}, {aou_psi.shape[0]} signatures, {aou_psi.shape[1]} diseases")
except Exception as e:
    print(f" ✗ Error loading AoU psi: {e}")
    aou_psi = None
    aou_psi_std = None

# Load MGB psi from a single checkpoint (no per-batch pooling for this cohort)
# NOTE(review): the primary file is named "mgb_model_initialized.pt"; confirm
# that it holds fitted psi rather than initial values before relying on it.
print("\n3. Loading MGB psi from checkpoint...")
mgb_psi_path = '/Users/sarahurbut/aladynoulli2/mgb_model_initialized.pt'
mgb_alt_path = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/model_with_kappa_bigam_MGB.pt'
try:
    if Path(mgb_psi_path).exists():
        mgb_psi = load_psi_from_checkpoint(mgb_psi_path)
        print(f" ✓ MGB: Psi shape {mgb_psi.shape}, {mgb_psi.shape[0]} signatures, {mgb_psi.shape[1]} diseases")
    else:
        print(f" ⚠ MGB checkpoint not found at {mgb_psi_path}")
        # Report the file that is actually tried, not a hand-typed name
        print(f" Trying alternative: {Path(mgb_alt_path).name}")
        if Path(mgb_alt_path).exists():
            mgb_psi = load_psi_from_checkpoint(mgb_alt_path)
            print(f" ✓ MGB: Psi shape {mgb_psi.shape}, {mgb_psi.shape[0]} signatures, {mgb_psi.shape[1]} diseases")
        else:
            # Say so explicitly instead of silently ending up with None
            print(f" ⚠ Alternative MGB checkpoint not found at {mgb_alt_path}")
            mgb_psi = None
except Exception as e:
    print(f" ✗ Error loading MGB psi: {e}")
    mgb_psi = None

print("\n" + "="*80)
print("PSI LOADING COMPLETE")
print("="*80)
================================================================================
LOADING AND POOLING PSI FROM COHORTS
================================================================================
1. Loading UKB psi from batches...
Found 40 files matching pattern: enrollment_model_W0.0001_batch_*_*.pt
Loaded psi from enrollment_model_W0.0001_batch_0_10000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_100000_110000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_10000_20000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_110000_120000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_120000_130000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_130000_140000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_140000_150000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_150000_160000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_160000_170000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_170000_180000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_180000_190000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_190000_200000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_200000_210000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_20000_30000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_210000_220000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_220000_230000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_230000_240000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_240000_250000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_250000_260000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_260000_270000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_270000_280000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_280000_290000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_290000_300000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_300000_310000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_30000_40000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_310000_320000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_320000_330000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_330000_340000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_340000_350000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_350000_360000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_360000_370000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_370000_380000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_380000_390000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_390000_400000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_40000_50000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_50000_60000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_60000_70000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_70000_80000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_80000_90000.pt, shape: (21, 348)
Loaded psi from enrollment_model_W0.0001_batch_90000_100000.pt, shape: (21, 348)
✓ Pooled 40 batches
✓ Psi shape: (21, 348)
✓ Psi range: [-5.2067, 1.8672]
✓ Psi mean: -2.1014, std (across batches): 0.0783
✓ UKB: Psi shape (21, 348), 21 signatures, 348 diseases
2. Loading AoU psi from batches...
Found 25 files matching pattern: aou_model_batch_*.pt
Loaded psi from aou_model_batch_0_0_10000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_10_100000_110000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_11_110000_120000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_12_120000_130000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_13_130000_140000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_14_140000_150000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_15_150000_160000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_16_160000_170000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_17_170000_180000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_18_180000_190000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_19_190000_200000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_1_10000_20000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_20_200000_210000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_21_210000_220000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_22_220000_230000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_23_230000_240000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_24_240000_243303.pt, shape: (21, 348)
Loaded psi from aou_model_batch_2_20000_30000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_3_30000_40000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_4_40000_50000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_5_50000_60000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_6_60000_70000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_7_70000_80000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_8_80000_90000.pt, shape: (21, 348)
Loaded psi from aou_model_batch_9_90000_100000.pt, shape: (21, 348)
✓ Pooled 25 batches
✓ Psi shape: (21, 348)
✓ Psi range: [-5.0487, 2.8246]
✓ Psi mean: -2.0600, std (across batches): 0.1713
✓ AoU: Psi shape (21, 348), 21 signatures, 348 diseases
3. Loading MGB psi from checkpoint...
✓ Loaded psi from mgb_model_initialized.pt, shape: (21, 346)
✓ MGB: Psi shape (21, 346), 21 signatures, 346 diseases
================================================================================
PSI LOADING COMPLETE
================================================================================
# Define function to compute similarity matrices based on psi
# For each signature k, compare the psi vectors (across diseases) between cohorts
# Using cosine similarity rescaled to [0, 1]: similarity = (cosine_similarity + 1) / 2
# Note: Similarity will be computed AFTER disease alignment in next cell
def compute_psi_similarity_matrix(psi1, psi2, disease_mask=None, method='cosine'):
    """
    Compute similarity matrix between two psi matrices.

    Entry (k1, k2) compares row k1 of ``psi1`` with row k2 of ``psi2``. Both
    matrices must describe the same diseases in the same column order.

    NOTE(review): cosine similarity is computed on the raw, uncentered psi
    values. When most entries share a sign, every pair of rows scores close
    to 1; ``method='correlation'`` centers each row first.

    Args:
        psi1: Psi matrix from cohort 1, shape (K1, D)
        psi2: Psi matrix from cohort 2, shape (K2, D)
        disease_mask: Optional boolean mask of length D selecting the columns
            (diseases) to compare; applied identically to both matrices
        method: 'cosine' or 'correlation'

    Returns:
        similarity_matrix: (K1, K2) matrix of similarities rescaled from
        [-1, 1] to [0, 1]. Pairs involving a zero vector (cosine) or a
        constant vector (correlation) score 0.0.

    Raises:
        ValueError: If ``method`` is unknown, the mask length matches neither
            matrix, or the two matrices do not have the same number of
            disease columns.
    """
    # Validate up front so a bad method fails even for empty inputs
    if method not in ('cosine', 'correlation'):
        raise ValueError(f"Unknown method: {method}")

    K1, D1 = psi1.shape
    K2, D2 = psi2.shape

    if disease_mask is not None:
        if len(disease_mask) != D1 and len(disease_mask) != D2:
            raise ValueError(f"Disease mask length {len(disease_mask)} doesn't match D1={D1} or D2={D2}")
        # One mask indexes both matrices, so they must share a column layout.
        # Previously a width mismatch surfaced as a NumPy IndexError.
        if D1 != D2:
            raise ValueError(
                f"Disease mask needs column-aligned inputs: psi1 has {D1} diseases, psi2 has {D2} diseases"
            )
        psi1 = psi1[:, disease_mask]
        psi2 = psi2[:, disease_mask]

    if psi1.shape[1] != psi2.shape[1]:
        raise ValueError(
            f"psi1 has {psi1.shape[1]} diseases, psi2 has {psi2.shape[1]} diseases; align them first"
        )

    # Per-row scale (norm or std) depends on one row only: compute it once
    if method == 'cosine':
        scale1 = [np.linalg.norm(psi1[k1, :]) for k1 in range(K1)]
        scale2 = [np.linalg.norm(psi2[k2, :]) for k2 in range(K2)]
    else:
        scale1 = [np.std(psi1[k1, :]) for k1 in range(K1)]
        scale2 = [np.std(psi2[k2, :]) for k2 in range(K2)]

    similarity_matrix = np.zeros((K1, K2))
    for k1 in range(K1):
        vec1 = psi1[k1, :]  # Disease associations for signature k1 in cohort 1
        for k2 in range(K2):
            vec2 = psi2[k2, :]  # Disease associations for signature k2 in cohort 2
            if scale1[k1] == 0 or scale2[k2] == 0:
                # Degenerate vector: the measure is undefined, report 0.0
                similarity = 0.0
            elif method == 'cosine':
                cosine_sim = np.dot(vec1, vec2) / (scale1[k1] * scale2[k2])
                similarity = (cosine_sim + 1) / 2  # Scale from [-1, 1] to [0, 1]
            else:
                corr, _ = pearsonr(vec1, vec2)
                similarity = (corr + 1) / 2  # Scale from [-1, 1] to [0, 1]
            similarity_matrix[k1, k2] = similarity
    return similarity_matrix
# Announce that the similarity helper is defined; nothing is computed here
print("\n".join([
    "="*80,
    "DEFINED PSI SIMILARITY FUNCTION",
    "="*80,
    "\n✓ Function ready for computing similarity matrices",
    " Similarity will be computed AFTER disease alignment (next cell)",
    "="*80,
]))
================================================================================ DEFINED PSI SIMILARITY FUNCTION ================================================================================ ✓ Function ready for computing similarity matrices Similarity will be computed AFTER disease alignment (next cell) ================================================================================
# Align diseases across cohorts by name
# This ensures we're comparing the same diseases when computing similarity
def align_psi_by_disease_names(psi1, disease_names1, psi2, disease_names2):
    """
    Align two psi matrices by disease names, returning aligned psi matrices and masks.

    Columns of both outputs follow the alphabetically sorted list of diseases
    present in both cohorts. If a name occurs more than once in a cohort, its
    first column is used.

    Args:
        psi1: Psi matrix from cohort 1, shape (K1, D1)
        disease_names1: Disease names for cohort 1, length D1 (list, tuple,
            numpy array or pandas Series)
        psi2: Psi matrix from cohort 2, shape (K2, D2)
        disease_names2: Disease names for cohort 2, length D2

    Returns:
        psi1_aligned: Aligned psi1, shape (K1, D_common)
        psi2_aligned: Aligned psi2, shape (K2, D_common)
        common_diseases: Sorted list of common disease names
        mask1: Boolean mask over cohort 1 diseases (True where common)
        mask2: Boolean mask over cohort 2 diseases (True where common)

    Raises:
        ValueError: If a name list's length differs from the number of
            columns of its psi matrix.
    """
    # Normalize to plain lists; .tolist() also unwraps numpy/pandas scalars
    disease_names1 = disease_names1.tolist() if hasattr(disease_names1, 'tolist') else list(disease_names1)
    disease_names2 = disease_names2.tolist() if hasattr(disease_names2, 'tolist') else list(disease_names2)

    # Names label columns one-to-one; a mismatch would misalign silently
    for label, names, psi in (('1', disease_names1, psi1), ('2', disease_names2, psi2)):
        if len(names) != psi.shape[1]:
            raise ValueError(
                f"Cohort {label}: {len(names)} disease names but psi has {psi.shape[1]} columns"
            )

    # name -> first column index, built once instead of repeated list.index()
    index1 = {}
    for i, d in enumerate(disease_names1):
        index1.setdefault(d, i)
    index2 = {}
    for i, d in enumerate(disease_names2):
        index2.setdefault(d, i)

    # Find common diseases
    common_diseases = sorted(index1.keys() & index2.keys())
    print(f" Common diseases: {len(common_diseases)} (out of {len(disease_names1)} and {len(disease_names2)})")

    # Column positions of the common diseases, in the shared sorted order
    indices1 = [index1[d] for d in common_diseases]
    indices2 = [index2[d] for d in common_diseases]

    # Extract aligned psi matrices
    psi1_aligned = psi1[:, indices1]
    psi2_aligned = psi2[:, indices2]

    # Masks are in each cohort's original column order
    common_set = set(common_diseases)
    mask1 = np.array([d in common_set for d in disease_names1], dtype=bool)
    mask2 = np.array([d in common_set for d in disease_names2], dtype=bool)
    return psi1_aligned, psi2_aligned, common_diseases, mask1, mask2
# Load disease names from checkpoints
print("="*80)
print("LOADING DISEASE NAMES AND ALIGNING PSI MATRICES")
print("="*80)

# UKB disease names come from the CSV column 'x'
ukb_disease_names = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/disease_names.csv')['x'].tolist()
print(f"\n✓ UKB: {len(ukb_disease_names)} diseases")

# MGB disease names are stored inside the MGB checkpoint
mgb_checkpoint_path = '/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/model_with_kappa_bigam_MGB.pt'
if not Path(mgb_checkpoint_path).exists():
    mgb_disease_names = None
else:
    mgb_checkpoint = torch.load(mgb_checkpoint_path, map_location='cpu', weights_only=False)
    mgb_disease_names = mgb_checkpoint['disease_names'] if 'disease_names' in mgb_checkpoint else None
    # Arrays and Series become plain lists; other containers are kept as-is
    if isinstance(mgb_disease_names, (np.ndarray, pd.Series)):
        mgb_disease_names = mgb_disease_names.tolist()

# AoU disease names are read from the first AoU batch file
aou_sample_batch = '/Users/sarahurbut/Library/CloudStorage/Dropbox/aou_batches/aou_model_batch_0_0_10000.pt'
if not Path(aou_sample_batch).exists():
    aou_disease_names = None
else:
    aou_sample = torch.load(aou_sample_batch, map_location='cpu', weights_only=False)
    aou_disease_names = aou_sample['disease_names'] if 'disease_names' in aou_sample else None
    if isinstance(aou_disease_names, (np.ndarray, pd.Series)):
        aou_disease_names = aou_disease_names.tolist()

# Similarity matrices stay None unless the corresponding alignment succeeds
ukb_mgb_sim = None
ukb_aou_sim = None

# UKB ↔ MGB: align by disease name, then compare signatures
if ukb_psi is None or mgb_psi is None or mgb_disease_names is None:
    common_ukb_mgb = None
    if ukb_psi is None:
        print(" ⚠ Could not align UKB ↔ MGB (UKB psi not loaded)")
    elif mgb_psi is None:
        print(" ⚠ Could not align UKB ↔ MGB (MGB psi not loaded)")
    else:
        print(" ⚠ Could not align UKB ↔ MGB (MGB disease names not found)")
else:
    print("\n1. Aligning UKB ↔ MGB by disease names...")
    try:
        ukb_psi_aligned, mgb_psi_aligned, common_ukb_mgb, _, _ = align_psi_by_disease_names(
            ukb_psi, ukb_disease_names, mgb_psi, mgb_disease_names
        )
        # Similarity is computed on the shared diseases only
        ukb_mgb_sim = compute_psi_similarity_matrix(ukb_psi_aligned, mgb_psi_aligned, method='cosine')
        print(f" ✓ Aligned similarity matrix shape: {ukb_mgb_sim.shape}")
        print(f" ✓ Similarity range: [{ukb_mgb_sim.min():.3f}, {ukb_mgb_sim.max():.3f}]")
        print(f" ✓ Mean similarity: {ukb_mgb_sim.mean():.3f}")
    except Exception as e:
        print(f" ✗ Error aligning/computing similarity: {e}")
        common_ukb_mgb = None
        ukb_mgb_sim = None

# UKB ↔ AoU: same procedure
if ukb_psi is None or aou_psi is None or aou_disease_names is None:
    common_ukb_aou = None
    if ukb_psi is None:
        print(" ⚠ Could not align UKB ↔ AoU (UKB psi not loaded)")
    elif aou_psi is None:
        print(" ⚠ Could not align UKB ↔ AoU (AoU psi not loaded)")
    else:
        print(" ⚠ Could not align UKB ↔ AoU (AoU disease names not found)")
else:
    print("\n2. Aligning UKB ↔ AoU by disease names...")
    try:
        ukb_psi_aligned_aou, aou_psi_aligned, common_ukb_aou, _, _ = align_psi_by_disease_names(
            ukb_psi, ukb_disease_names, aou_psi, aou_disease_names
        )
        ukb_aou_sim = compute_psi_similarity_matrix(ukb_psi_aligned_aou, aou_psi_aligned, method='cosine')
        print(f" ✓ Aligned similarity matrix shape: {ukb_aou_sim.shape}")
        print(f" ✓ Similarity range: [{ukb_aou_sim.min():.3f}, {ukb_aou_sim.max():.3f}]")
        print(f" ✓ Mean similarity: {ukb_aou_sim.mean():.3f}")
    except Exception as e:
        print(f" ✗ Error aligning/computing similarity: {e}")
        common_ukb_aou = None
        ukb_aou_sim = None

print("\n" + "="*80)
print("PSI ALIGNMENT COMPLETE")
print("="*80)
================================================================================ LOADING DISEASE NAMES AND ALIGNING PSI MATRICES ================================================================================ ✓ UKB: 348 diseases 1. Aligning UKB ↔ MGB by disease names... Common diseases: 346 (out of 348 and 346) ✓ Aligned similarity matrix shape: (21, 21) ✓ Similarity range: [0.847, 1.000] ✓ Mean similarity: 0.946 2. Aligning UKB ↔ AoU by disease names... Common diseases: 348 (out of 348 and 348) ✓ Aligned similarity matrix shape: (21, 21) ✓ Similarity range: [0.854, 1.000] ✓ Mean similarity: 0.953 ================================================================================ PSI ALIGNMENT COMPLETE ================================================================================
# Find best matches and compute summary statistics
# For each UKB signature, find the best-matching signature in the other cohort
# Summarize the best-match similarities (mean, median, spread, range, quartiles)
def find_best_psi_matches(similarity_matrix, cohort1_name='UKB', cohort2_name='Other'):
    """
    Pick, for every cohort-1 signature, its most similar cohort-2 signature.

    Args:
        similarity_matrix: (K1, K2) similarity matrix
        cohort1_name: Name of first cohort (used as a column name)
        cohort2_name: Name of second cohort (used as a column name)

    Returns:
        best_matches: DataFrame with one row per cohort-1 signature holding
            its index, the index of the best cohort-2 signature, that
            similarity, and the mean and std of its whole similarity row
        summary_stats: Dictionary with mean, median, std, min, max and the
            25th/75th percentiles of the best-match similarities
    """
    n_rows, _ = similarity_matrix.shape

    records = []
    for row_idx in range(n_rows):
        row = similarity_matrix[row_idx, :]
        top_col = np.argmax(row)  # first maximum wins on ties
        records.append({
            f'{cohort1_name}': row_idx,
            f'{cohort2_name}': int(top_col),
            'Similarity': similarity_matrix[row_idx, top_col],
            'Mean_Similarity': row.mean(),
            'Std_Similarity': row.std(),
        })
    best_matches_df = pd.DataFrame(records)

    # Summary statistics over the per-signature best scores
    best = best_matches_df['Similarity']
    summary_stats = {
        'mean_best_match': best.mean(),
        'median_best_match': best.median(),
        'std_best_match': best.std(),
        'min_best_match': best.min(),
        'max_best_match': best.max(),
        'q25': best.quantile(0.25),
        'q75': best.quantile(0.75),
    }
    return best_matches_df, summary_stats
print("="*80)
print("FINDING BEST PSI-BASED MATCHES")
print("="*80)

# UKB ↔ MGB: only when a similarity matrix was produced
if ukb_mgb_sim is None:
    best_matches_ukb_mgb_psi = stats_ukb_mgb = None
else:
    print("\n1. UKB ↔ MGB best matches...")
    best_matches_ukb_mgb_psi, stats_ukb_mgb = find_best_psi_matches(ukb_mgb_sim, 'UKB', 'MGB')
    print(f" Median similarity: {stats_ukb_mgb['median_best_match']:.3f}")
    print(f" Mean similarity: {stats_ukb_mgb['mean_best_match']:.3f}")
    print(f" Range: [{stats_ukb_mgb['min_best_match']:.3f}, {stats_ukb_mgb['max_best_match']:.3f}]")

# UKB ↔ AoU
if ukb_aou_sim is None:
    best_matches_ukb_aou_psi = stats_ukb_aou = None
else:
    print("\n2. UKB ↔ AoU best matches...")
    best_matches_ukb_aou_psi, stats_ukb_aou = find_best_psi_matches(ukb_aou_sim, 'UKB', 'AoU')
    print(f" Median similarity: {stats_ukb_aou['median_best_match']:.3f}")
    print(f" Mean similarity: {stats_ukb_aou['mean_best_match']:.3f}")
    print(f" Range: [{stats_ukb_aou['min_best_match']:.3f}, {stats_ukb_aou['max_best_match']:.3f}]")

print("\n" + "="*80)
print("BEST MATCHES IDENTIFIED")
print("="*80)
================================================================================ FINDING BEST PSI-BASED MATCHES ================================================================================ 1. UKB ↔ MGB best matches... Median similarity: 0.989 Mean similarity: 0.985 Range: [0.918, 1.000] 2. UKB ↔ AoU best matches... Median similarity: 0.989 Mean similarity: 0.987 Range: [0.921, 1.000] ================================================================================ BEST MATCHES IDENTIFIED ================================================================================
# Create heatmaps for psi-based similarity
# Similar style to original cluster-based heatmaps (clean, diagonal-focused)
def reorder_by_similarity_strength(sim_matrix, best_matches_df):
    """
    Build row/column orders that bring strong matches to the front.

    UKB signatures are ordered by best-match similarity, highest first. The
    other cohort's signatures follow in the order their UKB partners appear
    (each listed once), then any never-selected signatures in index order.

    Returns:
        (ukb_order, other_order): two lists of signature indices. Entries of
        ``other_order`` taken from ``best_matches_df`` keep whatever numeric
        type row access yields; callers cast to int before indexing.
    """
    # Rank UKB signatures by the strength of their best match
    ranked = best_matches_df.sort_values('Similarity', ascending=False)
    ukb_order = ranked['UKB'].values.tolist()

    # The partner-cohort column is whichever one is not a known bookkeeping column
    bookkeeping = ['UKB', 'Similarity', 'Mean_Similarity', 'Std_Similarity']
    candidates = [col for col in best_matches_df.columns if col not in bookkeeping]
    if candidates:
        other_cohort_col = candidates[0]
    else:
        # Fallback: first column that is not 'UKB'
        other_cohort_col = [col for col in best_matches_df.columns if col != 'UKB'][0]

    # Walk UKB signatures in ranked order, recording each new partner once
    other_order = []
    placed = set()
    for ukb_sig in ukb_order:
        rows = best_matches_df[best_matches_df['UKB'] == ukb_sig]
        if len(rows) == 0:
            continue
        partner = rows.iloc[0][other_cohort_col]
        if partner in placed:
            continue
        other_order.append(partner)
        placed.add(partner)

    # Signatures that were nobody's best match go last
    other_order.extend(sig for sig in range(sim_matrix.shape[1]) if sig not in placed)
    return ukb_order, other_order
print("="*80)
print("CREATING PSI-BASED SIMILARITY HEATMAPS")
print("="*80)

# One panel per cohort pair, side by side
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# Left panel: UKB vs MGB
if ukb_mgb_sim is None or best_matches_ukb_mgb_psi is None:
    # Nothing to draw: leave a placeholder message in the panel
    axes[0].text(0.5, 0.5, 'UKB ↔ MGB\nSimilarity unavailable',
                 ha='center', va='center', fontsize=12)
    axes[0].set_title('UKB vs MGB', fontsize=12, fontweight='bold')
else:
    # Put strong matches first, then cast to int so the orders can index
    ukb_order, mgb_order = reorder_by_similarity_strength(ukb_mgb_sim, best_matches_ukb_mgb_psi)
    ukb_order = np.asarray(ukb_order, dtype=int)
    mgb_order = np.asarray(mgb_order, dtype=int)

    # Permute rows and columns together
    ukb_mgb_sim_array = np.asarray(ukb_mgb_sim)
    ukb_mgb_sim_ordered = ukb_mgb_sim_array[np.ix_(ukb_order, mgb_order)]

    # Label the axes with the original signature indices
    ukb_mgb_sim_df = pd.DataFrame(
        ukb_mgb_sim_ordered,
        index=list(ukb_order),
        columns=list(mgb_order)
    )
    sns.heatmap(
        ukb_mgb_sim_df,
        ax=axes[0],
        cmap='Reds',
        vmin=0, vmax=1,
        annot=False,
        fmt='.2f',
        linewidths=0.5,
        linecolor='gray',
        cbar_kws={'label': 'Psi Similarity (cosine)'}
    )
    axes[0].set_title('Psi-Based Signature Similarity: UKB vs MGB\n(Weighted by Disease Associations)',
                      fontsize=12, fontweight='bold', pad=15)
    axes[0].set_xlabel('MGB Signature', fontsize=11, fontweight='bold')
    axes[0].set_ylabel('UKB Signature', fontsize=11, fontweight='bold')

# Right panel: UKB vs AoU
if ukb_aou_sim is None or best_matches_ukb_aou_psi is None:
    axes[1].text(0.5, 0.5, 'UKB ↔ AoU\nSimilarity unavailable',
                 ha='center', va='center', fontsize=12)
    axes[1].set_title('UKB vs AoU', fontsize=12, fontweight='bold')
else:
    ukb_order_aou, aou_order = reorder_by_similarity_strength(ukb_aou_sim, best_matches_ukb_aou_psi)
    ukb_order_aou = np.asarray(ukb_order_aou, dtype=int)
    aou_order = np.asarray(aou_order, dtype=int)

    ukb_aou_sim_array = np.asarray(ukb_aou_sim)
    ukb_aou_sim_ordered = ukb_aou_sim_array[np.ix_(ukb_order_aou, aou_order)]

    ukb_aou_sim_df = pd.DataFrame(
        ukb_aou_sim_ordered,
        index=list(ukb_order_aou),
        columns=list(aou_order)
    )
    sns.heatmap(
        ukb_aou_sim_df,
        ax=axes[1],
        cmap='Reds',
        vmin=0, vmax=1,
        annot=False,
        fmt='.2f',
        linewidths=0.5,
        linecolor='gray',
        cbar_kws={'label': 'Psi Similarity (cosine)'}
    )
    axes[1].set_title('Psi-Based Signature Similarity: UKB vs AoU\n(Weighted by Disease Associations)',
                      fontsize=12, fontweight='bold', pad=15)
    axes[1].set_xlabel('AoU Signature', fontsize=11, fontweight='bold')
    axes[1].set_ylabel('UKB Signature', fontsize=11, fontweight='bold')

# Write both a vector and a raster copy, then display
plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'psi_based_cross_cohort_similarity_heatmaps.pdf',
            dpi=300, bbox_inches='tight')
plt.savefig(OUTPUT_DIR / 'psi_based_cross_cohort_similarity_heatmaps.png',
            dpi=300, bbox_inches='tight')
print(f"✓ Saved psi-based similarity heatmaps to: {OUTPUT_DIR / 'psi_based_cross_cohort_similarity_heatmaps.pdf'}")
plt.show()
================================================================================ CREATING PSI-BASED SIMILARITY HEATMAPS ================================================================================ ✓ Saved psi-based similarity heatmaps to: psi_based_cross_cohort_similarity_heatmaps.pdf
# Compare psi-based similarity with cluster-based similarity
# This shows how using posterior psi (weighted by fit) differs from initial cluster assignments
# Check if cluster-based matches exist (from original analysis in earlier cells)
# Use try-except to safely check for variables that may not exist if earlier cells weren't run
# The cluster-based tables come from earlier cells; probe for them without
# failing when those cells were not run
try:
    has_cluster_mgb = best_matches_mgb is not None
except NameError:
    has_cluster_mgb = False
    print("⚠ Note: Cluster-based matches (best_matches_mgb) not found. Run cells 8-10 first for comparison.")

try:
    has_cluster_aou = best_matches_aou is not None
except NameError:
    has_cluster_aou = False
    print("⚠ Note: Cluster-based matches (best_matches_aou) not found. Run cells 8-10 first for comparison.")

# UKB ↔ MGB
if best_matches_ukb_mgb_psi is not None and has_cluster_mgb:
    print("="*80)
    print("COMPARISON: PSI-BASED vs CLUSTER-BASED MATCHES (UKB ↔ MGB)")
    print("="*80)

    # One row per UKB signature present in both tables
    comparison_mgb = (
        best_matches_ukb_mgb_psi[['UKB', 'MGB', 'Similarity']]
        .rename(columns={'MGB': 'MGB_psi', 'Similarity': 'Psi_Similarity'})
        .merge(
            best_matches_mgb[['UKB', 'MGB', 'Proportion']]
            .rename(columns={'MGB': 'MGB_cluster', 'Proportion': 'Cluster_Proportion'}),
            on='UKB',
            how='inner'
        )
    )

    # How often both approaches choose the same MGB signature
    matches_agree = (comparison_mgb['MGB_psi'] == comparison_mgb['MGB_cluster']).sum()
    matches_total = len(comparison_mgb)
    agreement_rate = matches_agree / matches_total if matches_total > 0 else 0
    print(f"\nAgreement between psi-based and cluster-based matches: {matches_agree}/{matches_total} ({agreement_rate*100:.1f}%)")
    print(f"\nMean psi similarity: {comparison_mgb['Psi_Similarity'].mean():.3f}")
    print(f"Mean cluster proportion: {comparison_mgb['Cluster_Proportion'].mean():.3f}")

    # List the disagreements, if any
    differing = comparison_mgb[comparison_mgb['MGB_psi'] != comparison_mgb['MGB_cluster']]
    if not differing.empty:
        print(f"\nSignatures where psi and cluster matches differ ({len(differing)} cases):")
        print(differing[['UKB', 'MGB_psi', 'MGB_cluster', 'Psi_Similarity', 'Cluster_Proportion']].to_string(index=False))
elif best_matches_ukb_mgb_psi is not None:
    print("\n⚠ Cannot compare UKB ↔ MGB: Cluster-based matches not available (run cells 8-10 first)")

# UKB ↔ AoU
if best_matches_ukb_aou_psi is not None and has_cluster_aou:
    print("\n" + "="*80)
    print("COMPARISON: PSI-BASED vs CLUSTER-BASED MATCHES (UKB ↔ AoU)")
    print("="*80)

    comparison_aou = (
        best_matches_ukb_aou_psi[['UKB', 'AoU', 'Similarity']]
        .rename(columns={'AoU': 'AoU_psi', 'Similarity': 'Psi_Similarity'})
        .merge(
            best_matches_aou[['UKB', 'AoU', 'Proportion']]
            .rename(columns={'AoU': 'AoU_cluster', 'Proportion': 'Cluster_Proportion'}),
            on='UKB',
            how='inner'
        )
    )

    matches_agree = (comparison_aou['AoU_psi'] == comparison_aou['AoU_cluster']).sum()
    matches_total = len(comparison_aou)
    agreement_rate = matches_agree / matches_total if matches_total > 0 else 0
    print(f"\nAgreement between psi-based and cluster-based matches: {matches_agree}/{matches_total} ({agreement_rate*100:.1f}%)")
    print(f"\nMean psi similarity: {comparison_aou['Psi_Similarity'].mean():.3f}")
    print(f"Mean cluster proportion: {comparison_aou['Cluster_Proportion'].mean():.3f}")

    differing = comparison_aou[comparison_aou['AoU_psi'] != comparison_aou['AoU_cluster']]
    if not differing.empty:
        print(f"\nSignatures where psi and cluster matches differ ({len(differing)} cases):")
        print(differing[['UKB', 'AoU_psi', 'AoU_cluster', 'Psi_Similarity', 'Cluster_Proportion']].to_string(index=False))
elif best_matches_ukb_aou_psi is not None:
    print("\n⚠ Cannot compare UKB ↔ AoU: Cluster-based matches not available (run cells 8-10 first)")
================================================================================ COMPARISON: PSI-BASED vs CLUSTER-BASED MATCHES (UKB ↔ MGB) ================================================================================ Agreement between psi-based and cluster-based matches: 7/20 (35.0%) Mean psi similarity: 0.985 Mean cluster proportion: 0.782 Signatures where psi and cluster matches differ (13 cases): UKB MGB_psi MGB_cluster Psi_Similarity Cluster_Proportion 0 20 5 0.988227 0.928571 2 20 7 0.988902 0.933333 3 20 12 0.918245 0.560976 5 20 5 0.995764 0.571429 7 20 1 0.986080 0.500000 10 20 16 0.989447 1.000000 12 20 3 0.988498 0.714286 13 20 18 0.979726 0.615385 14 20 14 0.992778 0.900000 15 20 6 0.992997 0.800000 16 20 4 0.983608 0.689655 17 20 8 0.986170 0.882353 19 20 15 0.972544 0.391304 ================================================================================ COMPARISON: PSI-BASED vs CLUSTER-BASED MATCHES (UKB ↔ AoU) ================================================================================ Agreement between psi-based and cluster-based matches: 6/20 (30.0%) Mean psi similarity: 0.986 Mean cluster proportion: 0.766 Signatures where psi and cluster matches differ (14 cases): UKB AoU_psi AoU_cluster Psi_Similarity Cluster_Proportion 0 20 16 0.987710 0.812500 1 17 6 0.983342 0.666667 2 20 12 0.989111 1.000000 3 18 5 0.921214 0.451220 4 20 3 0.992784 1.000000 5 20 16 0.995785 0.714286 7 20 1 0.986184 0.181818 10 20 8 0.989616 1.000000 13 2 5 0.982106 0.538462 14 20 19 0.992767 0.900000 15 20 14 0.992989 0.800000 16 20 1 0.983794 0.689655 17 20 12 0.986318 0.764706 19 0 13 0.973216 0.391304
def _as_numpy(values):
    """Return `values` as a NumPy array, detaching torch tensors first."""
    if hasattr(values, 'detach'):
        values = values.detach().cpu().numpy()
    return np.asarray(values)


def _psi_to_weights(psi):
    """Map psi to weights in [0, 1] with the logistic function (monotone in psi)."""
    # 0.5 * (1 + tanh(x / 2)) is algebraically the logistic sigmoid but cannot overflow for large |x|.
    return 0.5 * (1.0 + np.tanh(0.5 * _as_numpy(psi).astype(float)))


def calculate_weighted_modified_jaccard(ukb_psi, other_psi, ukb_clusters, other_clusters,
                                        ukb_disease_names, other_disease_names,
                                        common_diseases, biobank_name):
    """
    Calculate a psi-weighted modified Jaccard similarity between UKB clusters and another biobank.

    For each UKB cluster k:
    1. Collect the diseases assigned to k in UKB that also exist in the other biobank.
    2. Give every disease a non-negative weight w = sigmoid(psi[cluster, disease]).
    3. For each cluster j of the other biobank holding at least one of those diseases, compute
       weighted intersection = sum over shared diseases of min(w_ukb, w_other).
    4. Score = weighted intersection / weighted UKB cluster size; keep the best-scoring j.

    Why the weights are transformed: psi can be negative, so summing raw psi can give a zero or
    negative cluster "size", and adding the other cohort's raw psi to the numerator only lets the
    ratio exceed 1. With non-negative weights and a min-based intersection the score is always in
    [0, 1], and every UKB cluster that shares a disease with the other cohort gets a match.

    Parameters:
    -----------
    ukb_psi, other_psi : 2-D array-like (NumPy array or torch tensor)
        Indexed as psi[cluster_label, disease_index].
    ukb_clusters, other_clusters : 1-D array-like of integer cluster labels
        One label per disease, aligned with the corresponding disease-name list.
    ukb_disease_names, other_disease_names : sequence of str
        Disease names; position gives the disease index in that cohort.
    common_diseases : iterable of str
        Diseases to consider; names missing from either cohort are ignored.
    biobank_name : str
        Label of the other biobank (accepted for interface compatibility; not used).

    Returns:
    --------
    weighted_jaccard_scores : list
        Best score for each UKB cluster that has at least one common disease (sorted by label).
    cluster_details : dict
        Maps UKB cluster label -> dict with keys 'weighted_jaccard', 'best_match',
        'weighted_overlap' and 'ukb_weighted_size'.
    """
    ukb_weights = _psi_to_weights(ukb_psi)
    other_weights = _psi_to_weights(other_psi)
    ukb_labels = _as_numpy(ukb_clusters)
    other_labels = _as_numpy(other_clusters)

    # Disease name -> position in each cohort (last occurrence wins for duplicated names)
    ukb_disease_to_idx = {d: i for i, d in enumerate(ukb_disease_names)}
    other_disease_to_idx = {d: i for i, d in enumerate(other_disease_names)}

    # De-duplicate while keeping first-seen order; drop names missing from either cohort
    shared = [d for d in dict.fromkeys(common_diseases)
              if d in ukb_disease_to_idx and d in other_disease_to_idx]

    weighted_jaccard_scores = []
    cluster_details = {}
    if not shared:
        return weighted_jaccard_scores, cluster_details

    ukb_pos = np.array([ukb_disease_to_idx[d] for d in shared])
    other_pos = np.array([other_disease_to_idx[d] for d in shared])
    shared_ukb_labels = ukb_labels[ukb_pos]
    shared_other_labels = other_labels[other_pos]

    for ukb_sig in np.unique(ukb_labels).tolist():
        members = shared_ukb_labels == ukb_sig
        if not members.any():
            continue

        member_ukb_w = ukb_weights[ukb_sig, ukb_pos[members]]
        ukb_weighted_size = float(member_ukb_w.sum())
        if ukb_weighted_size <= 0:
            # Only reachable if every weight underflowed to zero
            continue

        member_other_labels = shared_other_labels[members]
        member_other_pos = other_pos[members]

        best_match_score = 0.0
        best_match_cluster = None
        best_weighted_overlap = 0.0

        # Only clusters containing at least one member disease can overlap; sorted order plus the
        # strict '>' below keeps the lowest label on ties.
        for other_sig in np.unique(member_other_labels).tolist():
            in_other = member_other_labels == other_sig
            member_other_w = other_weights[other_sig, member_other_pos[in_other]]
            weighted_overlap = float(np.minimum(member_ukb_w[in_other], member_other_w).sum())
            score = weighted_overlap / ukb_weighted_size
            if score > best_match_score:
                best_match_score = score
                best_match_cluster = other_sig
                best_weighted_overlap = weighted_overlap

        weighted_jaccard_scores.append(best_match_score)
        cluster_details[ukb_sig] = {
            'weighted_jaccard': best_match_score,
            'best_match': best_match_cluster,
            'weighted_overlap': best_weighted_overlap,
            'ukb_weighted_size': ukb_weighted_size
        }

    return weighted_jaccard_scores, cluster_details
# Gather cluster assignments and disease names for all three cohorts, then score them
print("\n" + "="*80)
print("LOADING CLUSTERS FOR WEIGHTED JACCARD")
print("="*80)


def _clusters_to_numpy(assignments):
    """Return cluster assignments as a NumPy array when they arrive as a torch tensor."""
    return assignments.numpy() if isinstance(assignments, torch.Tensor) else assignments


def _names_to_list(names):
    """Return disease names as a plain list (list/tuple input, or an object exposing `.values`)."""
    if isinstance(names, (list, tuple)):
        return list(names)
    if hasattr(names, 'values'):
        return names.values.tolist()
    return names


# UKB reference checkpoint
ukb_checkpoint_ref = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/model_with_kappa_bigam.pt', map_location='cpu')
ukb_clusters = _clusters_to_numpy(ukb_checkpoint_ref['clusters'])
ukb_disease_names = _names_to_list(ukb_checkpoint_ref['disease_names'])

# MGB checkpoint
mgb_checkpoint = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/model_with_kappa_bigam_MGB.pt', map_location='cpu')
mgb_clusters = _clusters_to_numpy(mgb_checkpoint['clusters'])
mgb_disease_names = _names_to_list(mgb_checkpoint['disease_names'])

# AoU checkpoint
aou_checkpoint = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/model_with_kappa_bigam_AOU.pt', map_location='cpu')
aou_clusters = _clusters_to_numpy(aou_checkpoint['clusters'])
aou_disease_names = _names_to_list(aou_checkpoint['disease_names'])

# Diseases present in UKB and in each comparison cohort
ukb_name_set = set(ukb_disease_names)
common_ukb_mgb = list(ukb_name_set & set(mgb_disease_names))
common_ukb_aou = list(ukb_name_set & set(aou_disease_names))

print(f"\nUKB: {len(ukb_disease_names)} diseases, {ukb_clusters.max()+1} signatures")
print(f"MGB: {len(mgb_disease_names)} diseases, {mgb_clusters.max()+1} signatures")
print(f"AoU: {len(aou_disease_names)} diseases, {aou_clusters.max()+1} signatures")
print(f"\nCommon diseases (UKB ↔ MGB): {len(common_ukb_mgb)}")
print(f"Common diseases (UKB ↔ AoU): {len(common_ukb_aou)}")

# Score every UKB signature against each cohort
print("\n" + "="*80)
print("CALCULATING WEIGHTED MODIFIED JACCARD")
print("="*80)

print("\nUKB ↔ MGB:")
weighted_jaccard_mgb, weighted_details_mgb = calculate_weighted_modified_jaccard(
    ukb_psi, mgb_psi,
    ukb_clusters, mgb_clusters,
    ukb_disease_names, mgb_disease_names,
    common_ukb_mgb, 'MGB',
)

print("\nUKB ↔ AoU:")
weighted_jaccard_aou, weighted_details_aou = calculate_weighted_modified_jaccard(
    ukb_psi, aou_psi,
    ukb_clusters, aou_clusters,
    ukb_disease_names, aou_disease_names,
    common_ukb_aou, 'AoU',
)

# Headline numbers for the weighted metric
print("\n" + "="*80)
print("COMPARING WEIGHTED vs BINARY JACCARD")
print("="*80)
print(f"\nWeighted Jaccard (UKB ↔ MGB): median = {np.median(weighted_jaccard_mgb):.3f}")
print(f"Weighted Jaccard (UKB ↔ AoU): median = {np.median(weighted_jaccard_aou):.3f}")
================================================================================ LOADING CLUSTERS FOR WEIGHTED JACCARD ================================================================================ UKB: 348 diseases, 20 signatures MGB: 346 diseases, 20 signatures AoU: 348 diseases, 20 signatures Common diseases (UKB ↔ MGB): 346 Common diseases (UKB ↔ AoU): 348 ================================================================================ CALCULATING WEIGHTED MODIFIED JACCARD ================================================================================ UKB ↔ MGB: UKB ↔ AoU: ================================================================================ COMPARING WEIGHTED vs BINARY JACCARD ================================================================================ Weighted Jaccard (UKB ↔ MGB): median = 1.295 Weighted Jaccard (UKB ↔ AoU): median = 0.954
# Side-by-side comparison of weighted vs binary best matches.
# The binary results (details_mgb / details_aou) come from the earlier binary Jaccard cells;
# re-run that calculation first if they are not in memory.
print("\n" + "="*80, "COMPARING WEIGHTED vs BINARY MATCHES", "="*80, sep="\n")
def create_match_comparison(weighted_details, binary_details, cohort_name):
    """
    Tabulate weighted vs binary best matches, one row per UKB signature.

    Signatures are taken from the union of the keys of both detail dicts. A signature
    missing from one dict gets match None and score 0 on that side. The returned
    DataFrame has the columns UKB_Sig, <cohort>_Binary, <cohort>_Weighted, Agreement
    ("✓" when both matches are equal, otherwise "✗"), Binary_Score, Weighted_Score
    and Score_Diff (weighted minus binary).
    """
    def _row_for(sig):
        weighted_entry = weighted_details.get(sig, {})
        binary_entry = binary_details.get(sig, {})

        w_match = weighted_entry.get('best_match', None)
        w_score = weighted_entry.get('weighted_jaccard', 0)
        b_match = binary_entry.get('best_match', None)
        b_score = binary_entry.get('jaccard', 0)

        if w_match == b_match:
            verdict = "✓"
        else:
            verdict = "✗"

        return {
            'UKB_Sig': sig,
            f'{cohort_name}_Binary': b_match,
            f'{cohort_name}_Weighted': w_match,
            'Agreement': verdict,
            'Binary_Score': b_score,
            'Weighted_Score': w_score,
            'Score_Diff': w_score - b_score
        }

    signature_ids = sorted(set(weighted_details) | set(binary_details))
    return pd.DataFrame([_row_for(sig) for sig in signature_ids])
# details_mgb / details_aou must already exist from the binary Jaccard calculation.


def _print_comparison(frame):
    """Print the full comparison table followed by the agreement tally."""
    print(frame.to_string(index=False))
    print(f"\nAgreement: {frame['Agreement'].value_counts().to_dict()}")


def _print_mismatches(pair_label, cohort, mismatches):
    """Print the rows whose weighted and binary matches differ, or a note that all agree."""
    if len(mismatches) > 0:
        print(f"\n{pair_label}: {len(mismatches)} signatures with different matches:")
        columns = ['UKB_Sig', f'{cohort}_Binary', f'{cohort}_Weighted', 'Binary_Score', 'Weighted_Score']
        print(mismatches[columns].to_string(index=False))
    else:
        print(f"\n{pair_label}: All matches agree! ✓")


def _print_healthy_hits(pair_label, cohort, hits):
    """Print how many UKB signatures were matched to signature 20 of the other cohort."""
    print(f"\n{pair_label}: {len(hits)} signatures matched to {cohort} Signature 20 (healthy)")
    if len(hits) > 0:
        print(hits[['UKB_Sig', f'{cohort}_Binary', f'{cohort}_Weighted']].to_string(index=False))


print("\nUKB ↔ MGB Match Comparison:")
comparison_mgb = create_match_comparison(weighted_details_mgb, details_mgb, 'MGB')
_print_comparison(comparison_mgb)

print("\n" + "-"*80)
print("UKB ↔ AoU Match Comparison:")
comparison_aou = create_match_comparison(weighted_details_aou, details_aou, 'AoU')
_print_comparison(comparison_aou)

# Rows where the two methods picked different partner signatures
print("\n" + "="*80)
print("SIGNATURES WHERE MATCHES DIFFER")
print("="*80)
differing_mgb = comparison_mgb.loc[comparison_mgb['Agreement'] == '✗']
_print_mismatches("UKB ↔ MGB", "MGB", differing_mgb)
differing_aou = comparison_aou.loc[comparison_aou['Agreement'] == '✗']
_print_mismatches("UKB ↔ AoU", "AoU", differing_aou)

# Any weighted match landing on signature 20 (the healthy signature)?
print("\n" + "="*80)
print("CHECKING FOR SIGNATURE 20 (HEALTHY) MATCHES")
print("="*80)
mgb_to_healthy = comparison_mgb.loc[comparison_mgb['MGB_Weighted'] == 20]
aou_to_healthy = comparison_aou.loc[comparison_aou['AoU_Weighted'] == 20]
_print_healthy_hits("UKB ↔ MGB", "MGB", mgb_to_healthy)
_print_healthy_hits("UKB ↔ AoU", "AoU", aou_to_healthy)

# Banner for the example-disease listing that follows
print("\n" + "="*80)
print("EXAMPLE DISEASES FROM MATCHED SIGNATURES (to verify biological coherence)")
print("="*80)
def show_top_diseases_in_signature(psi, clusters, disease_names, sig_idx, top_n=5):
    """
    Return the names of the highest-psi diseases assigned to one signature.

    Diseases whose cluster label equals `sig_idx` are ranked by psi[sig_idx, disease]
    in descending order (ties keep index order) and the first `top_n` are kept. Indices
    beyond the end of `disease_names` are dropped. Returns [] when the signature has
    no diseases.
    """
    members = np.where(clusters == sig_idx)[0]
    if len(members) == 0:
        return []

    # sorted() is stable, so equal psi values keep ascending index order even with reverse=True
    ranked = sorted(members, key=lambda j: psi[sig_idx, j], reverse=True)

    n_names = len(disease_names)
    return [disease_names[j] for j in ranked[:top_n] if j < n_names]
def _preview(names):
    """Join the first three names into a one-line preview."""
    return ', '.join(names[:3])


# Spot-check a few signatures against their weighted matches
example_sigs = [5, 6, 11]  # Cardiovascular, Cancer, Cerebrovascular

for ukb_sig in example_sigs:
    if ukb_sig not in weighted_details_mgb:
        continue
    mgb_match = weighted_details_mgb[ukb_sig]['best_match']
    print(f"\nUKB Signature {ukb_sig} → MGB Signature {mgb_match}:")
    ukb_top = show_top_diseases_in_signature(ukb_psi, ukb_clusters, ukb_disease_names, ukb_sig, top_n=5)
    mgb_top = show_top_diseases_in_signature(mgb_psi, mgb_clusters, mgb_disease_names, mgb_match, top_n=5)
    print(f" UKB top diseases: {_preview(ukb_top)}")
    print(f" MGB top diseases: {_preview(mgb_top)}")

for ukb_sig in example_sigs:
    if ukb_sig not in weighted_details_aou:
        continue
    aou_match = weighted_details_aou[ukb_sig]['best_match']
    print(f"\nUKB Signature {ukb_sig} → AoU Signature {aou_match}:")
    ukb_top = show_top_diseases_in_signature(ukb_psi, ukb_clusters, ukb_disease_names, ukb_sig, top_n=5)
    aou_top = show_top_diseases_in_signature(aou_psi, aou_clusters, aou_disease_names, aou_match, top_n=5)
    print(f" UKB top diseases: {_preview(ukb_top)}")
    print(f" AoU top diseases: {_preview(aou_top)}")
================================================================================
COMPARING WEIGHTED vs BINARY MATCHES
================================================================================
UKB ↔ MGB Match Comparison:
UKB_Sig MGB_Binary MGB_Weighted Agreement Binary_Score Weighted_Score Score_Diff
0 5 5.0 ✓ 0.928571 38.935484 38.006912
1 2 2.0 ✓ 0.952381 2.204099 1.251718
2 7 NaN ✗ 0.933333 0.000000 -0.933333
3 12 12.0 ✓ 0.560976 1.917849 1.356873
4 0 0.0 ✓ 1.000000 1.683012 0.683012
5 5 NaN ✗ 0.571429 0.000000 -0.571429
6 11 11.0 ✓ 0.875000 1.130759 0.255759
7 1 NaN ✗ 0.500000 0.000000 -0.500000
8 19 19.0 ✓ 0.678571 1.617055 0.938483
9 9 9.0 ✓ 1.000000 2.346589 1.346589
10 16 16.0 ✓ 1.000000 7.786339 6.786339
11 10 10.0 ✓ 0.875000 1.947143 1.072143
12 3 3.0 ✓ 0.714286 1.067492 0.353206
13 18 18.0 ✓ 0.615385 0.838097 0.222713
14 14 NaN ✗ 0.900000 0.000000 -0.900000
15 6 6.0 ✓ 0.800000 1.638514 0.838514
16 4 NaN ✗ 0.689655 0.000000 -0.689655
17 8 NaN ✗ 0.882353 0.000000 -0.882353
18 13 13.0 ✓ 0.777778 1.459974 0.682196
19 15 15.0 ✓ 0.391304 0.679451 0.288147
Agreement: {'✓': 14, '✗': 6}
--------------------------------------------------------------------------------
UKB ↔ AoU Match Comparison:
UKB_Sig AoU_Binary AoU_Weighted Agreement Binary_Score Weighted_Score Score_Diff
0 16 NaN ✗ 0.812500 0.000000 -0.812500
1 6 6.0 ✓ 0.666667 1.367644 0.700978
2 12 NaN ✗ 1.000000 0.000000 -1.000000
3 5 5.0 ✓ 0.451220 1.408293 0.957074
4 3 3.0 ✓ 1.000000 0.914335 -0.085665
5 16 NaN ✗ 0.714286 0.000000 -0.714286
6 11 11.0 ✓ 1.000000 1.155166 0.155166
7 1 NaN ✗ 0.181818 0.000000 -0.181818
8 7 7.0 ✓ 0.750000 1.298151 0.548151
9 9 9.0 ✓ 1.000000 1.942268 0.942268
10 8 8.0 ✓ 1.000000 10.131687 9.131687
11 4 4.0 ✓ 1.000000 1.537151 0.537151
12 15 15.0 ✓ 1.000000 1.582985 0.582985
13 5 2.0 ✗ 0.538462 0.812027 0.273566
14 19 NaN ✗ 0.900000 0.000000 -0.900000
15 14 14.0 ✓ 0.800000 0.993149 0.193149
16 1 NaN ✗ 0.689655 0.000000 -0.689655
17 12 NaN ✗ 0.764706 0.000000 -0.764706
18 10 10.0 ✓ 0.666667 1.088136 0.421469
19 13 13.0 ✓ 0.391304 0.637314 0.246009
Agreement: {'✓': 12, '✗': 8}
================================================================================
SIGNATURES WHERE MATCHES DIFFER
================================================================================
UKB ↔ MGB: 6 signatures with different matches:
UKB_Sig MGB_Binary MGB_Weighted Binary_Score Weighted_Score
2 7 NaN 0.933333 0.0
5 5 NaN 0.571429 0.0
7 1 NaN 0.500000 0.0
14 14 NaN 0.900000 0.0
16 4 NaN 0.689655 0.0
17 8 NaN 0.882353 0.0
UKB ↔ AoU: 8 signatures with different matches:
UKB_Sig AoU_Binary AoU_Weighted Binary_Score Weighted_Score
0 16 NaN 0.812500 0.000000
2 12 NaN 1.000000 0.000000
5 16 NaN 0.714286 0.000000
7 1 NaN 0.181818 0.000000
13 5 2.0 0.538462 0.812027
14 19 NaN 0.900000 0.000000
16 1 NaN 0.689655 0.000000
17 12 NaN 0.764706 0.000000
================================================================================
CHECKING FOR SIGNATURE 20 (HEALTHY) MATCHES
================================================================================
UKB ↔ MGB: 0 signatures matched to MGB Signature 20 (healthy)
UKB ↔ AoU: 0 signatures matched to AoU Signature 20 (healthy)
================================================================================
EXAMPLE DISEASES FROM MATCHED SIGNATURES (to verify biological coherence)
================================================================================
UKB Signature 5 → MGB Signature None:
UKB top diseases: Other acute and subacute forms of ischemic heart disease, Unstable angina (intermediate coronary syndrome), Coronary atherosclerosis
MGB top diseases:
UKB Signature 6 → MGB Signature 11:
UKB top diseases: Secondary malignant neoplasm of liver, Secondary malignancy of bone, Secondary malignancy of respiratory organs
MGB top diseases: Malignant neoplasm, other, Secondary malignancy of respiratory organs, Secondary malignancy of bone
UKB Signature 11 → MGB Signature 10:
UKB top diseases: Cerebral artery occlusion, with cerebral infarction, Occlusion and stenosis of precerebral arteries, Late effects of cerebrovascular disease
MGB top diseases: Cerebrovascular disease, Late effects of cerebrovascular disease, Occlusion of cerebral arteries
UKB Signature 5 → AoU Signature None:
UKB top diseases: Other acute and subacute forms of ischemic heart disease, Unstable angina (intermediate coronary syndrome), Coronary atherosclerosis
AoU top diseases:
UKB Signature 6 → AoU Signature 11:
UKB top diseases: Secondary malignant neoplasm of liver, Secondary malignancy of bone, Secondary malignancy of respiratory organs
AoU top diseases: Secondary malignant neoplasm, Secondary malignancy of lymph nodes, Secondary malignancy of respiratory organs
UKB Signature 11 → AoU Signature 4:
UKB top diseases: Cerebral artery occlusion, with cerebral infarction, Occlusion and stenosis of precerebral arteries, Late effects of cerebrovascular disease
AoU top diseases: Cerebrovascular disease, Transient cerebral ischemia, Cerebral ischemia
# Re-run weighted Jaccard with corrected MGB model
print("\n" + "="*80)
print("RE-RUNNING WEIGHTED JACCARD WITH CORRECTED MGB MODEL")
print("="*80)

# Verify that the mgb_psi used below really is the psi stored in mgb_model_initialized.pt.
# The calculation uses the mgb_psi already in memory, so the checkpoint is loaded only for
# comparison; the ✓ is printed only when the two arrays actually match.
print("\nVerifying MGB psi source:")
mgb_checkpoint_verify = torch.load('/Users/sarahurbut/aladynoulli2/mgb_model_initialized.pt', map_location='cpu')
if 'model_state_dict' in mgb_checkpoint_verify and 'psi' in mgb_checkpoint_verify['model_state_dict']:
    mgb_psi_verify = mgb_checkpoint_verify['model_state_dict']['psi']
    if torch.is_tensor(mgb_psi_verify):
        mgb_psi_verify = mgb_psi_verify.detach().cpu().numpy()
    print(f" MGB psi range: [{mgb_psi_verify.min():.3f}, {mgb_psi_verify.max():.3f}]")
    print(f" MGB psi mean: {mgb_psi_verify.mean():.3f}")

    # Compare against the array that is passed to the calculation below
    if torch.is_tensor(mgb_psi):
        mgb_psi_in_use = mgb_psi.detach().cpu().numpy()
    else:
        mgb_psi_in_use = np.asarray(mgb_psi)
    mgb_psi_reference = np.asarray(mgb_psi_verify)
    # Shape check first: np.allclose would raise on shapes that cannot be broadcast
    psi_matches = (
        mgb_psi_in_use.shape == mgb_psi_reference.shape
        and np.allclose(mgb_psi_in_use, mgb_psi_reference, equal_nan=True)
    )
    if psi_matches:
        print(f" ✓ Using mgb_model_initialized.pt (trained with corrected E/prevalence)")
    else:
        print(" ⚠ mgb_psi in memory does NOT match psi in mgb_model_initialized.pt; "
              "results below use the mgb_psi loaded earlier")
else:
    print(" ⚠ mgb_model_initialized.pt has no model_state_dict['psi']; cannot verify MGB psi source")

# Re-calculate weighted Jaccard
print("\nUKB ↔ MGB:")
weighted_jaccard_mgb_new, weighted_details_mgb_new = calculate_weighted_modified_jaccard(
    ukb_psi, mgb_psi, ukb_clusters, mgb_clusters,
    ukb_disease_names, mgb_disease_names, common_ukb_mgb, 'MGB'
)
print("\nUKB ↔ AoU:")
weighted_jaccard_aou_new, weighted_details_aou_new = calculate_weighted_modified_jaccard(
    ukb_psi, aou_psi, ukb_clusters, aou_clusters,
    ukb_disease_names, aou_disease_names, common_ukb_aou, 'AoU'
)

# Compare results
print("\n" + "="*80)
print("WEIGHTED JACCARD RESULTS (with corrected MGB)")
print("="*80)
print(f"\nWeighted Jaccard (UKB ↔ MGB): median = {np.median(weighted_jaccard_mgb_new):.3f}")
print(f" Range: [{np.min(weighted_jaccard_mgb_new):.3f}, {np.max(weighted_jaccard_mgb_new):.3f}]")
print(f" IQR: [{np.percentile(weighted_jaccard_mgb_new, 25):.3f}, {np.percentile(weighted_jaccard_mgb_new, 75):.3f}]")
print(f"\nWeighted Jaccard (UKB ↔ AoU): median = {np.median(weighted_jaccard_aou_new):.3f}")
print(f" Range: [{np.min(weighted_jaccard_aou_new):.3f}, {np.max(weighted_jaccard_aou_new):.3f}]")
print(f" IQR: [{np.percentile(weighted_jaccard_aou_new, 25):.3f}, {np.percentile(weighted_jaccard_aou_new, 75):.3f}]")

# Binary (unweighted) medians for reference; jaccard_mgb / jaccard_aou come from earlier cells
print(f"\nBinary Jaccard (UKB ↔ MGB): median = {np.median(jaccard_mgb):.3f}")
print(f"Binary Jaccard (UKB ↔ AoU): median = {np.median(jaccard_aou):.3f}")

# A ratio-type similarity should never exceed 1.0; any count above zero signals a formula problem
mgb_over_one = np.sum(np.array(weighted_jaccard_mgb_new) > 1.0)
aou_over_one = np.sum(np.array(weighted_jaccard_aou_new) > 1.0)
print(f"\n⚠ Weighted scores > 1.0 (formula issue):")
print(f" UKB ↔ MGB: {mgb_over_one}/{len(weighted_jaccard_mgb_new)} signatures")
print(f" UKB ↔ AoU: {aou_over_one}/{len(weighted_jaccard_aou_new)} signatures")
================================================================================ RE-RUNNING WEIGHTED JACCARD WITH CORRECTED MGB MODEL ================================================================================ Verifying MGB psi source: MGB psi range: [-5.188, 3.483] MGB psi mean: -2.039 ✓ Using mgb_model_initialized.pt (trained with corrected E/prevalence) UKB ↔ MGB: UKB ↔ AoU: ================================================================================ WEIGHTED JACCARD RESULTS (with corrected MGB) ================================================================================ Weighted Jaccard (UKB ↔ MGB): median = 1.295 Range: [0.000, 38.935] IQR: [0.000, 1.925] Weighted Jaccard (UKB ↔ AoU): median = 0.954 Range: [0.000, 10.132] IQR: [0.000, 1.378] Binary Jaccard (UKB ↔ MGB): median = 0.838 Binary Jaccard (UKB ↔ AoU): median = 0.782 ⚠ Weighted scores > 1.0 (formula issue): UKB ↔ MGB: 12/20 signatures UKB ↔ AoU: 9/20 signatures
# Banner for the per-signature inspection of disagreeing matches
print("\n" + "="*80, "DISEASES IN SIGNATURES WITH DIFFERENT MATCHES", "="*80, sep="\n")
def get_diseases_in_signature(clusters, disease_names, sig_idx, top_n=10):
    """
    Return up to `top_n` names of diseases whose cluster label equals `sig_idx`.

    Names come back in ascending disease-index order (no ranking is applied);
    indices beyond the end of `disease_names` are skipped.
    """
    member_positions = np.where(clusters == sig_idx)[0]
    n_names = len(disease_names)

    selected = []
    for pos in member_positions:
        if pos < n_names:
            selected.append(disease_names[pos])
    return selected[:top_n]
def get_diseases_with_psi(psi, clusters, disease_names, sig_idx, top_n=10):
    """
    Return (disease name, psi value) pairs for the highest-psi diseases of one signature.

    Diseases whose cluster label equals `sig_idx` are ranked by psi[sig_idx, disease]
    in descending order (ties keep index order) and the first `top_n` are kept. Indices
    beyond the end of `disease_names` are dropped. Returns [] when the signature has
    no diseases.
    """
    members = np.where(clusters == sig_idx)[0]
    if len(members) == 0:
        return []

    # Stable sort: equal psi values keep ascending index order even with reverse=True
    ranked = sorted(members, key=lambda j: psi[sig_idx, j], reverse=True)

    n_names = len(disease_names)
    pairs = []
    for j in ranked[:top_n]:
        if j < n_names:
            pairs.append((disease_names[j], psi[sig_idx, j]))
    return pairs
def _int_or_none(cell):
    """Cast a table cell to int, mapping missing (NaN/None) cells to None."""
    return None if pd.isna(cell) else int(cell)


def _print_ranked(header, ranked):
    """Print a header line followed by a numbered list of (disease, psi) pairs."""
    print(header)
    for rank, (disease, psi_val) in enumerate(ranked, 1):
        print(f" {rank}. {disease} (psi: {psi_val:.3f})")


# ---- UKB ↔ MGB: walk through every signature whose two matches disagree ----
print("\n" + "-"*80)
print("UKB ↔ MGB: Inspecting Differing Matches")
print("-"*80)
for idx, row in differing_mgb.iterrows():
    ukb_sig = int(row['UKB_Sig'])
    mgb_binary = _int_or_none(row['MGB_Binary'])
    mgb_weighted = _int_or_none(row['MGB_Weighted'])

    print(f"\n{'='*80}")
    print(f"UKB Signature {ukb_sig}")
    print(f" Binary match: MGB Signature {mgb_binary} (score: {row['Binary_Score']:.3f})")
    print(f" Weighted match: MGB Signature {mgb_weighted} (score: {row['Weighted_Score']:.3f})")
    print(f"{'='*80}")

    # UKB side, ranked by psi
    ukb_top = get_diseases_with_psi(ukb_psi, ukb_clusters, ukb_disease_names, ukb_sig, top_n=8)
    _print_ranked(f"\nUKB Signature {ukb_sig} - Top diseases (by psi):", ukb_top)

    # Partner chosen by the binary metric
    if mgb_binary is not None:
        mgb_binary_top = get_diseases_with_psi(mgb_psi, mgb_clusters, mgb_disease_names, mgb_binary, top_n=8)
        _print_ranked(f"\nMGB Signature {mgb_binary} (BINARY MATCH) - Top diseases (by psi):", mgb_binary_top)

    # Partner chosen by the weighted metric (may be absent)
    if mgb_weighted is None:
        print(f"\nMGB Signature (WEIGHTED MATCH): NaN - No match found")
    else:
        mgb_weighted_top = get_diseases_with_psi(mgb_psi, mgb_clusters, mgb_disease_names, mgb_weighted, top_n=8)
        _print_ranked(f"\nMGB Signature {mgb_weighted} (WEIGHTED MATCH) - Top diseases (by psi):", mgb_weighted_top)

# ---- UKB ↔ AoU: same walk-through ----
print("\n\n" + "-"*80)
print("UKB ↔ AoU: Inspecting Differing Matches")
print("-"*80)
for idx, row in differing_aou.iterrows():
    ukb_sig = int(row['UKB_Sig'])
    aou_binary = _int_or_none(row['AoU_Binary'])
    aou_weighted = _int_or_none(row['AoU_Weighted'])

    print(f"\n{'='*80}")
    print(f"UKB Signature {ukb_sig}")
    print(f" Binary match: AoU Signature {aou_binary} (score: {row['Binary_Score']:.3f})")
    print(f" Weighted match: AoU Signature {aou_weighted} (score: {row['Weighted_Score']:.3f})")
    print(f"{'='*80}")

    # UKB side, ranked by psi
    ukb_top = get_diseases_with_psi(ukb_psi, ukb_clusters, ukb_disease_names, ukb_sig, top_n=8)
    _print_ranked(f"\nUKB Signature {ukb_sig} - Top diseases (by psi):", ukb_top)

    # Partner chosen by the binary metric
    if aou_binary is not None:
        aou_binary_top = get_diseases_with_psi(aou_psi, aou_clusters, aou_disease_names, aou_binary, top_n=8)
        _print_ranked(f"\nAoU Signature {aou_binary} (BINARY MATCH) - Top diseases (by psi):", aou_binary_top)

    # Partner chosen by the weighted metric (may be absent)
    if aou_weighted is None:
        print(f"\nAoU Signature (WEIGHTED MATCH): NaN - No match found")
    else:
        aou_weighted_top = get_diseases_with_psi(aou_psi, aou_clusters, aou_disease_names, aou_weighted, top_n=8)
        _print_ranked(f"\nAoU Signature {aou_weighted} (WEIGHTED MATCH) - Top diseases (by psi):", aou_weighted_top)

# ---- Why do some signatures end up with no weighted match / a zero score? ----
print("\n\n" + "="*80)
print("INVESTIGATING NaN/0.000000 WEIGHTED SCORES")
print("="*80)
print("\nFor signatures with NaN weighted matches, let's check if they have any diseases in common:")
for idx, row in differing_mgb.iterrows():
    # Only rows with a missing weighted match or a zero weighted score are of interest
    if not pd.isna(row['MGB_Weighted']) and row['Weighted_Score'] != 0.0:
        continue

    ukb_sig = int(row['UKB_Sig'])
    mgb_binary = _int_or_none(row['MGB_Binary'])

    # Names of the diseases assigned to this UKB signature
    ukb_disease_indices = np.where(ukb_clusters == ukb_sig)[0]
    ukb_diseases_set = {ukb_disease_names[i] for i in ukb_disease_indices if i < len(ukb_disease_names)}

    if mgb_binary is None:
        continue

    # Names of the diseases assigned to the binary-matched MGB signature
    mgb_disease_indices = np.where(mgb_clusters == mgb_binary)[0]
    mgb_diseases_set = {mgb_disease_names[i] for i in mgb_disease_indices if i < len(mgb_disease_names)}

    common = ukb_diseases_set & mgb_diseases_set
    if len(common) == 0:
        continue

    print(f"\nUKB Sig {ukb_sig} ↔ MGB Sig {mgb_binary} (binary match):")
    print(f" Common diseases: {len(common)}")
    print(f" Sample common diseases: {list(common)[:5]}")

    # Show psi on both sides for a handful of the shared diseases
    ukb_common_psi = []
    mgb_common_psi = []
    for d in list(common)[:5]:
        if not (d in ukb_disease_names and d in mgb_disease_names):
            continue
        ukb_idx = ukb_disease_names.index(d)
        mgb_idx = mgb_disease_names.index(d)
        if ukb_clusters[ukb_idx] == ukb_sig and mgb_clusters[mgb_idx] == mgb_binary:
            ukb_psi_val = ukb_psi[ukb_sig, ukb_idx]
            mgb_psi_val = mgb_psi[mgb_binary, mgb_idx]
            ukb_common_psi.append(ukb_psi_val)
            mgb_common_psi.append(mgb_psi_val)
            print(f" {d}: UKB psi={ukb_psi_val:.3f}, MGB psi={mgb_psi_val:.3f}")
================================================================================
DISEASES IN SIGNATURES WITH DIFFERENT MATCHES
================================================================================
--------------------------------------------------------------------------------
UKB ↔ MGB: Inspecting Differing Matches
--------------------------------------------------------------------------------
================================================================================
UKB Signature 0
Binary match: MGB Signature 5 (score: 0.929)
Weighted match: MGB Signature 14 (score: 3.900)
================================================================================
UKB Signature 0 - Top diseases (by psi):
1. Paroxysmal ventricular tachycardia (psi: 0.682)
2. Pericarditis (psi: 0.629)
3. Primary/intrinsic cardiomyopathies (psi: 0.340)
4. Paroxysmal supraventricular tachycardia (psi: 0.334)
5. Left bundle branch block (psi: 0.215)
6. Rheumatic disease of the heart valves (psi: 0.108)
7. Heart failure NOS (psi: 0.056)
8. Cardiomegaly (psi: 0.007)
MGB Signature 5 (BINARY MATCH) - Top diseases (by psi):
1. Coronary atherosclerosis (psi: 0.839)
2. Congenital anomalies of great vessels (psi: 0.738)
3. Arrhythmia (cardiac) NOS (psi: 0.731)
4. Mitral valve disease (psi: 0.713)
5. Unstable angina (intermediate coronary syndrome) (psi: 0.588)
6. Pericarditis (psi: 0.577)
7. Rheumatic disease of the heart valves (psi: 0.157)
8. Right bundle branch block (psi: -0.095)
MGB Signature 14 (WEIGHTED MATCH) - Top diseases (by psi):
1. Pulmonary collapse; interstitial and compensatory emphysema (psi: 3.212)
2. Cancer of bronchus; lung (psi: 2.941)
3. Pleurisy; pleural effusion (psi: 2.705)
4. Empyema and pneumothorax (psi: 2.640)
5. Bacterial pneumonia (psi: 1.837)
6. Pneumonia (psi: 1.570)
7. Pneumococcal pneumonia (psi: 1.300)
8. Emphysema (psi: 0.677)
================================================================================
UKB Signature 1
Binary match: MGB Signature 2 (score: 0.952)
Weighted match: MGB Signature 12 (score: 0.129)
================================================================================
UKB Signature 1 - Top diseases (by psi):
1. Hallux rigidus (psi: 1.095)
2. Peripheral enthesopathies and allied syndromes (psi: 1.070)
3. Other disorders of synovium, tendon, and bursa (psi: 0.973)
4. Hammer toe (acquired) (psi: 0.952)
5. Acquired toe deformities (psi: 0.918)
6. Osteoarthrosis NOS (psi: 0.786)
7. Hallux valgus (Bunion) (psi: 0.770)
8. Joint effusions (psi: 0.697)
MGB Signature 2 (BINARY MATCH) - Top diseases (by psi):
1. Hallux rigidus (psi: 1.254)
2. Bursitis (psi: 1.252)
3. Rheumatoid arthritis (psi: 0.895)
4. Osteoarthritis; localized (psi: 0.588)
5. Other peripheral nerve disorders (psi: 0.267)
6. Hallux valgus (Bunion) (psi: 0.167)
7. Osteoarthrosis NOS (psi: 0.119)
8. Osteoarthrosis, localized, primary (psi: -0.065)
MGB Signature 12 (WEIGHTED MATCH) - Top diseases (by psi):
1. Chronic tonsillitis and adenoiditis (psi: 2.311)
2. Epilepsy, recurrent seizures, convulsions (psi: 2.227)
3. Acute pancreatitis (psi: 2.219)
4. Celiac disease (psi: 1.976)
5. Polymyalgia Rheumatica (psi: 1.924)
6. Ptosis of eyelid (psi: 1.854)
7. Redundant prepuce and phimosis/BXO (psi: 1.604)
8. Facial nerve disorders [CN7] (psi: 1.513)
================================================================================
UKB Signature 2
Binary match: MGB Signature 7 (score: 0.933)
Weighted match: MGB Signature None (score: 0.000)
================================================================================
UKB Signature 2 - Top diseases (by psi):
1. Ulcer of esophagus (psi: 0.661)
2. Stricture and stenosis of esophagus (psi: 0.617)
3. Gastric ulcer (psi: 0.298)
4. Duodenal ulcer (psi: 0.258)
5. Barrett's esophagus (psi: 0.210)
6. Other disorders of stomach and duodenum (psi: 0.157)
7. Esophageal bleeding (varices/hemorrhage) (psi: 0.036)
8. Duodenitis (psi: 0.028)
MGB Signature 7 (BINARY MATCH) - Top diseases (by psi):
1. Other specified gastritis (psi: 3.156)
2. Gastritis and duodenitis (psi: 2.732)
3. Stricture and stenosis of esophagus (psi: 2.603)
4. Esophagitis, GERD and related diseases (psi: 2.467)
5. Esophageal bleeding (varices/hemorrhage) (psi: 2.462)
6. Ulcer of esophagus (psi: 2.317)
7. Barrett's esophagus (psi: 1.995)
8. Other disorders of stomach and duodenum (psi: 1.726)
MGB Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 5
Binary match: MGB Signature 5 (score: 0.571)
Weighted match: MGB Signature None (score: 0.000)
================================================================================
UKB Signature 5 - Top diseases (by psi):
1. Other acute and subacute forms of ischemic heart disease (psi: 0.190)
2. Unstable angina (intermediate coronary syndrome) (psi: 0.150)
3. Coronary atherosclerosis (psi: -0.056)
4. Hypercholesterolemia (psi: -0.329)
5. Angina pectoris (psi: -0.948)
6. Other chronic ischemic heart disease, unspecified (psi: -1.043)
7. Myocardial infarction (psi: -1.183)
MGB Signature 5 (BINARY MATCH) - Top diseases (by psi):
1. Coronary atherosclerosis (psi: 0.839)
2. Congenital anomalies of great vessels (psi: 0.738)
3. Arrhythmia (cardiac) NOS (psi: 0.731)
4. Mitral valve disease (psi: 0.713)
5. Unstable angina (intermediate coronary syndrome) (psi: 0.588)
6. Pericarditis (psi: 0.577)
7. Rheumatic disease of the heart valves (psi: 0.157)
8. Right bundle branch block (psi: -0.095)
MGB Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 7
Binary match: MGB Signature 1 (score: 0.500)
Weighted match: MGB Signature None (score: 0.000)
================================================================================
UKB Signature 7 - Top diseases (by psi):
1. Rheumatism, unspecified and fibrositis (psi: 1.213)
2. Myalgia and myositis unspecified (psi: 0.221)
3. Cervicalgia (psi: 0.161)
4. Sleep apnea (psi: -0.010)
5. Abdominal pain (psi: -0.106)
6. Hyperlipidemia (psi: -0.137)
7. Other inflammatory spondylopathies (psi: -0.190)
8. Migraine (psi: -0.283)
MGB Signature 1 (BINARY MATCH) - Top diseases (by psi):
1. Chronic periodontitis (psi: 0.844)
2. Asthma (psi: 0.087)
3. Irritable Bowel Syndrome (psi: 0.017)
4. Alcoholic liver damage (psi: -0.086)
5. Viral infection (psi: -0.187)
6. Acute upper respiratory infections of multiple or unspecified sites (psi: -0.205)
7. Hypopotassemia (psi: -0.221)
8. Anxiety disorder (psi: -0.377)
MGB Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 14
Binary match: MGB Signature 14 (score: 0.900)
Weighted match: MGB Signature None (score: 0.000)
================================================================================
UKB Signature 14 - Top diseases (by psi):
1. Empyema and pneumothorax (psi: 0.747)
2. Emphysema (psi: 0.118)
3. Bronchiectasis (psi: 0.049)
4. Postinflammatory pulmonary fibrosis (psi: 0.012)
5. Chronic airway obstruction (psi: -0.085)
6. Pneumonia (psi: -0.193)
7. Pneumococcal pneumonia (psi: -0.214)
8. Obstructive chronic bronchitis (psi: -0.310)
MGB Signature 14 (BINARY MATCH) - Top diseases (by psi):
1. Pulmonary collapse; interstitial and compensatory emphysema (psi: 3.212)
2. Cancer of bronchus; lung (psi: 2.941)
3. Pleurisy; pleural effusion (psi: 2.705)
4. Empyema and pneumothorax (psi: 2.640)
5. Bacterial pneumonia (psi: 1.837)
6. Pneumonia (psi: 1.570)
7. Pneumococcal pneumonia (psi: 1.300)
8. Emphysema (psi: 0.677)
MGB Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 16
Binary match: MGB Signature 4 (score: 0.690)
Weighted match: MGB Signature None (score: 0.000)
================================================================================
UKB Signature 16 - Top diseases (by psi):
1. Peritonitis and retroperitoneal infections (psi: 0.317)
2. Bacterial pneumonia (psi: 0.031)
3. Neutropenia (psi: -0.113)
4. Chronic ulcer of skin (psi: -0.134)
5. Other local infections of skin and subcutaneous tissue (psi: -0.172)
6. Hyperpotassemia (psi: -0.275)
7. Gram negative septicemia (psi: -0.286)
8. E. coli (psi: -0.325)
MGB Signature 4 (BINARY MATCH) - Top diseases (by psi):
1. E. coli (psi: 1.576)
2. Carbuncle and furuncle (psi: 1.497)
3. Hypovolemia (psi: 1.288)
4. Orthostatic hypotension (psi: 1.201)
5. Iron deficiency anemias, unspecified or not due to blood loss (psi: 1.045)
6. Hyposmolality and/or hyponatremia (psi: 0.809)
7. Sepsis (psi: 0.664)
8. Other local infections of skin and subcutaneous tissue (psi: 0.644)
MGB Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 17
Binary match: MGB Signature 8 (score: 0.882)
Weighted match: MGB Signature None (score: 0.000)
================================================================================
UKB Signature 17 - Top diseases (by psi):
1. Regional enteritis (psi: 0.907)
2. Gastrointestinal complications (psi: 0.709)
3. Neoplasm of unspecified nature of digestive system (psi: 0.658)
4. Hemorrhage of rectum and anus (psi: 0.593)
5. Ventral hernia (psi: 0.508)
6. Ulcerative colitis (psi: 0.330)
7. Malignant neoplasm of rectum, rectosigmoid junction, and anus (psi: 0.146)
8. Colon cancer (psi: 0.075)
MGB Signature 8 (BINARY MATCH) - Top diseases (by psi):
1. Anal and rectal polyp (psi: 2.321)
2. Ulcerative colitis (psi: 1.702)
3. Regional enteritis (psi: 1.677)
4. Malignant neoplasm of rectum, rectosigmoid junction, and anus (psi: 1.190)
5. Peritoneal adhesions (postoperative) (postinfection) (psi: 1.188)
6. Colon cancer (psi: 0.768)
7. Peritonitis and retroperitoneal infections (psi: 0.444)
8. Benign neoplasm of colon (psi: 0.422)
MGB Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 19
Binary match: MGB Signature 15 (score: 0.391)
Weighted match: MGB Signature 12 (score: 0.389)
================================================================================
UKB Signature 19 - Top diseases (by psi):
1. Hypertensive chronic kidney disease (psi: 1.829)
2. Chronic glomerulonephritis, NOS (psi: 1.735)
3. Disorders of lacrimal system (psi: 1.095)
4. Breast cancer [female] (psi: 1.016)
5. Renal failure NOS (psi: 0.969)
6. Viral warts & HPV (psi: 0.896)
7. Scar conditions and fibrosis of skin (psi: 0.767)
8. Epiphora (psi: 0.601)
MGB Signature 15 (BINARY MATCH) - Top diseases (by psi):
1. Sebaceous cyst (psi: 1.191)
2. Benign neoplasm of skin (psi: 0.476)
3. Disorder of skin and subcutaneous tissue NOS (psi: 0.015)
4. Other non-epithelial cancer of skin (psi: 0.007)
5. Atopic/contact dermatitis due to other or unspecified (psi: -0.376)
6. Hemangioma and lymphangioma, any site (psi: -0.462)
7. Melanomas of skin (psi: -0.463)
8. Viral warts & HPV (psi: -0.526)
MGB Signature 12 (WEIGHTED MATCH) - Top diseases (by psi):
1. Chronic tonsillitis and adenoiditis (psi: 2.311)
2. Epilepsy, recurrent seizures, convulsions (psi: 2.227)
3. Acute pancreatitis (psi: 2.219)
4. Celiac disease (psi: 1.976)
5. Polymyalgia Rheumatica (psi: 1.924)
6. Ptosis of eyelid (psi: 1.854)
7. Redundant prepuce and phimosis/BXO (psi: 1.604)
8. Facial nerve disorders [CN7] (psi: 1.513)
--------------------------------------------------------------------------------
UKB ↔ AoU: Inspecting Differing Matches
--------------------------------------------------------------------------------
================================================================================
UKB Signature 0
Binary match: AoU Signature 16 (score: 0.812)
Weighted match: AoU Signature None (score: 0.000)
================================================================================
UKB Signature 0 - Top diseases (by psi):
1. Paroxysmal ventricular tachycardia (psi: 0.682)
2. Pericarditis (psi: 0.629)
3. Primary/intrinsic cardiomyopathies (psi: 0.340)
4. Paroxysmal supraventricular tachycardia (psi: 0.334)
5. Left bundle branch block (psi: 0.215)
6. Rheumatic disease of the heart valves (psi: 0.108)
7. Heart failure NOS (psi: 0.056)
8. Cardiomegaly (psi: 0.007)
AoU Signature 16 (BINARY MATCH) - Top diseases (by psi):
1. Coronary atherosclerosis (psi: 2.570)
2. Myocardial infarction (psi: 2.061)
3. Congestive heart failure (CHF) NOS (psi: 1.965)
4. Cardiomegaly (psi: 1.908)
5. Rheumatic disease of the heart valves (psi: 1.845)
6. Other forms of chronic heart disease (psi: 1.832)
7. Paroxysmal ventricular tachycardia (psi: 1.755)
8. Pulmonary heart disease (psi: 1.726)
AoU Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 2
Binary match: AoU Signature 12 (score: 1.000)
Weighted match: AoU Signature None (score: 0.000)
================================================================================
UKB Signature 2 - Top diseases (by psi):
1. Ulcer of esophagus (psi: 0.661)
2. Stricture and stenosis of esophagus (psi: 0.617)
3. Gastric ulcer (psi: 0.298)
4. Duodenal ulcer (psi: 0.258)
5. Barrett's esophagus (psi: 0.210)
6. Other disorders of stomach and duodenum (psi: 0.157)
7. Esophageal bleeding (varices/hemorrhage) (psi: 0.036)
8. Duodenitis (psi: 0.028)
AoU Signature 12 (BINARY MATCH) - Top diseases (by psi):
1. Esophagitis, GERD and related diseases (psi: 1.714)
2. GERD (psi: 1.683)
3. Hemorrhage of rectum and anus (psi: 1.518)
4. Diverticulosis (psi: 1.477)
5. Iron deficiency anemias, unspecified or not due to blood loss (psi: 1.431)
6. Diaphragmatic hernia (psi: 1.408)
7. Esophageal bleeding (varices/hemorrhage) (psi: 1.357)
8. Hemorrhage of gastrointestinal tract (psi: 1.305)
AoU Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 5
Binary match: AoU Signature 16 (score: 0.714)
Weighted match: AoU Signature None (score: 0.000)
================================================================================
UKB Signature 5 - Top diseases (by psi):
1. Other acute and subacute forms of ischemic heart disease (psi: 0.190)
2. Unstable angina (intermediate coronary syndrome) (psi: 0.150)
3. Coronary atherosclerosis (psi: -0.056)
4. Hypercholesterolemia (psi: -0.329)
5. Angina pectoris (psi: -0.948)
6. Other chronic ischemic heart disease, unspecified (psi: -1.043)
7. Myocardial infarction (psi: -1.183)
AoU Signature 16 (BINARY MATCH) - Top diseases (by psi):
1. Coronary atherosclerosis (psi: 2.570)
2. Myocardial infarction (psi: 2.061)
3. Congestive heart failure (CHF) NOS (psi: 1.965)
4. Cardiomegaly (psi: 1.908)
5. Rheumatic disease of the heart valves (psi: 1.845)
6. Other forms of chronic heart disease (psi: 1.832)
7. Paroxysmal ventricular tachycardia (psi: 1.755)
8. Pulmonary heart disease (psi: 1.726)
AoU Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 7
Binary match: AoU Signature 1 (score: 0.182)
Weighted match: AoU Signature None (score: 0.000)
================================================================================
UKB Signature 7 - Top diseases (by psi):
1. Rheumatism, unspecified and fibrositis (psi: 1.213)
2. Myalgia and myositis unspecified (psi: 0.221)
3. Cervicalgia (psi: 0.161)
4. Sleep apnea (psi: -0.010)
5. Abdominal pain (psi: -0.106)
6. Hyperlipidemia (psi: -0.137)
7. Other inflammatory spondylopathies (psi: -0.190)
8. Migraine (psi: -0.283)
AoU Signature 1 (BINARY MATCH) - Top diseases (by psi):
1. Other anemias (psi: 1.261)
2. Hypopotassemia (psi: 1.170)
3. Bacterial infection NOS (psi: 1.169)
4. Hyposmolality and/or hyponatremia (psi: 1.162)
5. Hypovolemia (psi: 1.108)
6. Acute renal failure (psi: 1.012)
7. E. coli (psi: 0.985)
8. Other acute and subacute forms of ischemic heart disease (psi: 0.972)
AoU Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 13
Binary match: AoU Signature 5 (score: 0.538)
Weighted match: AoU Signature 2 (score: 0.812)
================================================================================
UKB Signature 13 - Top diseases (by psi):
1. Urinary incontinence (psi: 1.867)
2. Other symptoms/disorders or the urinary system (psi: 1.168)
3. Urethral stricture (not specified as infectious) (psi: 1.124)
4. Chronic prostatitis (psi: 1.108)
5. Cystitis (psi: 1.068)
6. Chronic cystitis (psi: 1.052)
7. Bladder neck obstruction (psi: 0.928)
8. Malignant neoplasm of bladder (psi: 0.900)
AoU Signature 5 (BINARY MATCH) - Top diseases (by psi):
1. Hyperplasia of prostate (psi: 1.960)
2. Inguinal hernia (psi: 1.858)
3. Hearing loss (psi: 1.694)
4. Raynaud's syndrome (psi: 1.688)
5. Lipoma of skin and subcutaneous tissue (psi: 1.572)
6. Viral Enteritis (psi: 1.561)
7. Lipoma (psi: 1.480)
8. Disturbances in tooth eruption (psi: 1.449)
AoU Signature 2 (WEIGHTED MATCH) - Top diseases (by psi):
1. Other symptoms/disorders or the urinary system (psi: 2.575)
2. Urinary incontinence (psi: 2.034)
3. Other disorders of bladder (psi: 1.935)
4. Uterine/Uterovaginal prolapse (psi: 1.729)
5. Functional disorders of bladder (psi: 1.655)
6. Urinary tract infection (psi: 1.583)
7. Bladder neck obstruction (psi: 1.518)
8. Cystitis (psi: 1.456)
================================================================================
UKB Signature 14
Binary match: AoU Signature 19 (score: 0.900)
Weighted match: AoU Signature None (score: 0.000)
================================================================================
UKB Signature 14 - Top diseases (by psi):
1. Empyema and pneumothorax (psi: 0.747)
2. Emphysema (psi: 0.118)
3. Bronchiectasis (psi: 0.049)
4. Postinflammatory pulmonary fibrosis (psi: 0.012)
5. Chronic airway obstruction (psi: -0.085)
6. Pneumonia (psi: -0.193)
7. Pneumococcal pneumonia (psi: -0.214)
8. Obstructive chronic bronchitis (psi: -0.310)
AoU Signature 19 (BINARY MATCH) - Top diseases (by psi):
1. Pneumonia (psi: 2.825)
2. Chronic airway obstruction (psi: 2.768)
3. Pleurisy; pleural effusion (psi: 2.126)
4. Pulmonary collapse; interstitial and compensatory emphysema (psi: 1.999)
5. Respiratory failure (psi: 1.848)
6. Emphysema (psi: 1.788)
7. Bronchiectasis (psi: 1.771)
8. Pneumococcal pneumonia (psi: 1.759)
AoU Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 16
Binary match: AoU Signature 1 (score: 0.690)
Weighted match: AoU Signature None (score: 0.000)
================================================================================
UKB Signature 16 - Top diseases (by psi):
1. Peritonitis and retroperitoneal infections (psi: 0.317)
2. Bacterial pneumonia (psi: 0.031)
3. Neutropenia (psi: -0.113)
4. Chronic ulcer of skin (psi: -0.134)
5. Other local infections of skin and subcutaneous tissue (psi: -0.172)
6. Hyperpotassemia (psi: -0.275)
7. Gram negative septicemia (psi: -0.286)
8. E. coli (psi: -0.325)
AoU Signature 1 (BINARY MATCH) - Top diseases (by psi):
1. Other anemias (psi: 1.261)
2. Hypopotassemia (psi: 1.170)
3. Bacterial infection NOS (psi: 1.169)
4. Hyposmolality and/or hyponatremia (psi: 1.162)
5. Hypovolemia (psi: 1.108)
6. Acute renal failure (psi: 1.012)
7. E. coli (psi: 0.985)
8. Other acute and subacute forms of ischemic heart disease (psi: 0.972)
AoU Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
UKB Signature 17
Binary match: AoU Signature 12 (score: 0.765)
Weighted match: AoU Signature None (score: 0.000)
================================================================================
UKB Signature 17 - Top diseases (by psi):
1. Regional enteritis (psi: 0.907)
2. Gastrointestinal complications (psi: 0.709)
3. Neoplasm of unspecified nature of digestive system (psi: 0.658)
4. Hemorrhage of rectum and anus (psi: 0.593)
5. Ventral hernia (psi: 0.508)
6. Ulcerative colitis (psi: 0.330)
7. Malignant neoplasm of rectum, rectosigmoid junction, and anus (psi: 0.146)
8. Colon cancer (psi: 0.075)
AoU Signature 12 (BINARY MATCH) - Top diseases (by psi):
1. Esophagitis, GERD and related diseases (psi: 1.714)
2. GERD (psi: 1.683)
3. Hemorrhage of rectum and anus (psi: 1.518)
4. Diverticulosis (psi: 1.477)
5. Iron deficiency anemias, unspecified or not due to blood loss (psi: 1.431)
6. Diaphragmatic hernia (psi: 1.408)
7. Esophageal bleeding (varices/hemorrhage) (psi: 1.357)
8. Hemorrhage of gastrointestinal tract (psi: 1.305)
AoU Signature (WEIGHTED MATCH): NaN - No match found
================================================================================
INVESTIGATING NaN/0.000000 WEIGHTED SCORES
================================================================================
For signatures with NaN weighted matches, let's check if they have any diseases in common:
UKB Sig 2 ↔ MGB Sig 7 (binary match):
Common diseases: 14
Sample common diseases: ['Duodenitis', 'Ulcer of esophagus', 'Esophagitis, GERD and related diseases', 'Other specified gastritis', 'Duodenal ulcer']
Duodenitis: UKB psi=0.028, MGB psi=1.685
Ulcer of esophagus: UKB psi=0.661, MGB psi=2.317
Esophagitis, GERD and related diseases: UKB psi=-0.375, MGB psi=2.467
Other specified gastritis: UKB psi=-0.859, MGB psi=3.156
Duodenal ulcer: UKB psi=0.258, MGB psi=0.586
UKB Sig 5 ↔ MGB Sig 5 (binary match):
Common diseases: 4
Sample common diseases: ['Unstable angina (intermediate coronary syndrome)', 'Coronary atherosclerosis', 'Other chronic ischemic heart disease, unspecified', 'Myocardial infarction']
Unstable angina (intermediate coronary syndrome): UKB psi=0.150, MGB psi=0.588
Coronary atherosclerosis: UKB psi=-0.056, MGB psi=0.839
Other chronic ischemic heart disease, unspecified: UKB psi=-1.043, MGB psi=-0.786
Myocardial infarction: UKB psi=-1.183, MGB psi=-0.759
UKB Sig 7 ↔ MGB Sig 1 (binary match):
Common diseases: 11
Sample common diseases: ['Cervicalgia', 'Myalgia and myositis unspecified', 'Sleep apnea', 'GERD', 'Major depressive disorder']
Cervicalgia: UKB psi=0.161, MGB psi=-1.169
Myalgia and myositis unspecified: UKB psi=0.221, MGB psi=-1.230
Sleep apnea: UKB psi=-0.010, MGB psi=-1.122
GERD: UKB psi=-1.246, MGB psi=-0.624
Major depressive disorder: UKB psi=-0.896, MGB psi=-0.431
UKB Sig 14 ↔ MGB Sig 14 (binary match):
Common diseases: 9
Sample common diseases: ['Bronchiectasis', 'Empyema and pneumothorax', 'Obstructive chronic bronchitis', 'Pneumonia', 'Emphysema']
Bronchiectasis: UKB psi=0.049, MGB psi=0.494
Empyema and pneumothorax: UKB psi=0.747, MGB psi=2.640
Obstructive chronic bronchitis: UKB psi=-0.310, MGB psi=0.402
Pneumonia: UKB psi=-0.193, MGB psi=1.570
Emphysema: UKB psi=0.118, MGB psi=0.677
UKB Sig 16 ↔ MGB Sig 4 (binary match):
Common diseases: 20
Sample common diseases: ['Orthostatic hypotension', 'Gram negative septicemia', 'Acidosis', 'Other local infections of skin and subcutaneous tissue', 'Neutropenia']
Orthostatic hypotension: UKB psi=-0.523, MGB psi=1.201
Gram negative septicemia: UKB psi=-0.286, MGB psi=0.207
Acidosis: UKB psi=-0.475, MGB psi=0.025
Other local infections of skin and subcutaneous tissue: UKB psi=-0.172, MGB psi=0.644
Neutropenia: UKB psi=-0.113, MGB psi=-0.053
UKB Sig 17 ↔ MGB Sig 8 (binary match):
Common diseases: 15
Sample common diseases: ['Hemorrhage of rectum and anus', 'Other disorders of intestine', 'Hemorrhoids', 'Anal and rectal polyp', 'Ulcerative colitis']
Hemorrhage of rectum and anus: UKB psi=0.593, MGB psi=0.170
Other disorders of intestine: UKB psi=-0.499, MGB psi=0.410
Hemorrhoids: UKB psi=-1.029, MGB psi=0.223
Anal and rectal polyp: UKB psi=-0.831, MGB psi=2.321
Ulcerative colitis: UKB psi=0.330, MGB psi=1.702
Summary: Psi-Based vs Cluster-Based Similarity¶
Key Difference:
- Cluster-based (original): Uses initial cluster assignments (binary, pre-fit), treats all diseases equally
- Psi-based (new): Uses posterior psi from the fitted models (continuous, weighted by association strength), so it reflects the model fit and weights each disease by its cluster-specific affinity
Advantages of Psi-Based Approach:
- Reflects actual learned associations, not just initial assignments
- Weights diseases by strength of association with each signature
- Captures more nuanced relationships (continuous values, not just 0/1)
- Better represents the biological signal learned during training
# Create summary table comparing both approaches.
# One row is emitted per (cohort comparison, method); the table is printed and
# written to OUTPUT_DIR / 'psi_vs_cluster_similarity_summary.csv'.
print("="*80)
print("SUMMARY: PSI-BASED vs CLUSTER-BASED SIMILARITY")
print("="*80)


def _format_summary_row(comparison, method, median, mean, q25, q75):
    """Return one summary-table row with every statistic formatted to 3 decimals."""
    return {
        'Comparison': comparison,
        'Method': method,
        'Median Similarity': f"{median:.3f}",
        'Mean Similarity': f"{mean:.3f}",
        'IQR': f"[{q25:.3f}, {q75:.3f}]",
    }


def _psi_summary_row(comparison, stats):
    """Build the psi-based row from a stats mapping.

    Reads the keys 'median_best_match', 'mean_best_match', 'q25' and 'q75'.
    """
    return _format_summary_row(
        comparison,
        'Psi-based (cosine)',
        stats['median_best_match'],
        stats['mean_best_match'],
        stats['q25'],
        stats['q75'],
    )


def _jaccard_summary_row(comparison, median, values):
    """Build the cluster-based row from a precomputed median and the raw values.

    Mean and quartiles are computed here with numpy from ``values``
    (presumably the per-signature Jaccard similarities -- confirm against the
    cell that defines them).
    """
    return _format_summary_row(
        comparison,
        'Cluster-based (Jaccard)',
        median,
        np.mean(values),
        np.percentile(values, 25),
        np.percentile(values, 75),
    )


summary_data = []

# Psi-based results: a stats object of None means that comparison is unavailable.
if stats_ukb_mgb is not None:
    summary_data.append(_psi_summary_row('UKB ↔ MGB', stats_ukb_mgb))
if stats_ukb_aou is not None:
    summary_data.append(_psi_summary_row('UKB ↔ AoU', stats_ukb_aou))

# Add cluster-based results for comparison.
# Both the median and the underlying values are read below, so require both
# names to exist; checking only the median would raise NameError when the
# values array is missing instead of skipping the row.
if 'median_mgb' in locals() and 'jaccard_mgb' in locals():
    summary_data.append(_jaccard_summary_row('UKB ↔ MGB', median_mgb, jaccard_mgb))
if 'median_aou' in locals() and 'jaccard_aou' in locals():
    summary_data.append(_jaccard_summary_row('UKB ↔ AoU', median_aou, jaccard_aou))

if summary_data:
    summary_df = pd.DataFrame(summary_data)
    # Concatenate rather than passing two print() arguments: print's default
    # separator would put a space in front of the header row and shift it one
    # column relative to the data rows.
    print("\n" + summary_df.to_string(index=False))
    # Save to CSV
    summary_df.to_csv(OUTPUT_DIR / 'psi_vs_cluster_similarity_summary.csv', index=False)
    print(f"\n✓ Saved comparison summary to: {OUTPUT_DIR / 'psi_vs_cluster_similarity_summary.csv'}")
else:
    print("\n⚠ No summary data available (psi loading may have failed)")

print("\n" + "="*80)
print("ANALYSIS COMPLETE")
print("="*80)
Key Findings¶
Strong cross-cohort correspondence: Median modified Jaccard similarity of 0.800 (80.0%) across all UKB signatures
Consistent patterns: Both MGB and AoU show similar levels of correspondence with UKB (83.8% and 77.5% respectively)
Robust biological signatures: The high correspondence suggests that disease signatures represent robust biological patterns that transcend population differences
Methodology: The modified Jaccard similarity uses set intersections to quantify the proportion of diseases in each UKB cluster that are also found in the best-matching cluster in the comparison cohort