Verification: Corrected E Matrix and Prevalence¶
This notebook verifies that the saved files in data_for_running/ match the corrected versions computed from the workflow:
- E_matrix_corrected.pt: Verifies it matches recomputation from E_matrix.pt + censor_info.csv
- prevalence_t_corrected.pt: Verifies it matches recomputation from Y_tensor.pt + E_matrix_corrected.pt
This confirms the workflow was executed correctly.
In [2]:
# ============================================================================
# Setup
# ============================================================================
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.ndimage import gaussian_filter1d
data_dir = Path('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/')
print("Setup complete")
Setup complete
Step 1: Verify E_matrix_corrected.pt¶
In [ ]:
import torch
import pandas as pd
import numpy as np
# Load data
censor_df = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/censor_info.csv')
T = 52
E=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_matrix.pt')
# Convert to timepoints
max_timepoints = torch.tensor(
(censor_df['max_censor'].values - 30).clip(0, T-1).astype(int)
)
# Only update censored cases (where E == T-1)
censored_mask = (E == T - 1) # Shape: (N, D)
# For each patient, cap censored diseases to their max_timepoint
# Expand max_timepoints to match E shape
max_timepoints_expanded = max_timepoints.unsqueeze(1).expand_as(E)
# Update only censored positions
E_corrected = torch.where(
censored_mask,
torch.minimum(E, max_timepoints_expanded),
E # Keep event times as-is
)
================================================================================ STEP 1: VERIFYING E_MATRIX_CORRECTED.PT ================================================================================ ✓ Loaded saved E_matrix_corrected: torch.Size([407878, 348]) ✓ Loaded original E_matrix: torch.Size([407878, 348]) ✓ Loaded censor_info.csv: 407878 patients ✓ Recomputed E_matrix_corrected: torch.Size([407878, 348]) ================================================================================ COMPARISON ================================================================================ ✅ PERFECT MATCH! E_matrix_corrected.pt matches recomputation
In [4]:
# ============================================================================
# Step 1: Verify E_corrected matches recomputation (UKB pattern)
# ============================================================================
print("\n" + "="*80)
print("STEP 1: VERIFYING E_MATRIX_CORRECTED.PT")
print("="*80)
Y=torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt', weights_only=False)
# Load saved E_corrected
E_corrected_saved = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_matrix_corrected.pt', weights_only=False)
if torch.is_tensor(E_corrected_saved):
E_corrected_saved = E_corrected_saved.numpy()
print(f"\n✓ Loaded saved E_matrix_corrected: {E_corrected_saved.shape}")
# Load original E (assuming full follow-up)
E_original = torch.load('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_matrix.pt', weights_only=False)
if torch.is_tensor(E_original):
E_original = E_original.numpy()
print(f"✓ Loaded original E_matrix: {E_original.shape}")
# Load censor info
censor_df = pd.read_csv('/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/censor_info.csv')
print(f"✓ Loaded censor_info.csv: {len(censor_df)} patients")
# Recompute E_corrected using UKB pattern
T = E_original.shape[0] if len(E_original.shape) == 1 else Y.shape[2] # Adjust based on your data
max_timepoints = (censor_df['max_censor'].values - 30).clip(0, T-1).astype(int)
# Only update censored cases (where E == T-1)
censored_mask = (E_original == T - 1) # Shape: (N, D)
# Expand max_timepoints to match E shape
max_timepoints_expanded = np.broadcast_to(
max_timepoints[:, np.newaxis],
E_original.shape
)
# Update only censored positions
E_corrected_recomputed = np.where(
censored_mask,
np.minimum(E_original, max_timepoints_expanded),
E_original # Keep event times as-is
)
print(f"\n✓ Recomputed E_matrix_corrected: {E_corrected_recomputed.shape}")
# Compare
print(f"\n{'='*80}")
print("COMPARISON")
print(f"{'='*80}")
if np.array_equal(E_corrected_saved, E_corrected_recomputed):
print("\n✅ PERFECT MATCH! E_matrix_corrected.pt matches recomputation")
else:
diff = np.abs(E_corrected_saved - E_corrected_recomputed)
print(f"\n⚠️ MISMATCH DETECTED!")
print(f" Max difference: {diff.max()}")
print(f" Number of differences: {(diff > 0).sum()}")
================================================================================ STEP 1: VERIFYING E_MATRIX_CORRECTED.PT ================================================================================ ✓ Loaded saved E_matrix_corrected: (407878, 348) ✓ Loaded original E_matrix: (407878, 348) ✓ Loaded censor_info.csv: 407878 patients ✓ Recomputed E_matrix_corrected: (407878, 348) ================================================================================ COMPARISON ================================================================================ ✅ PERFECT MATCH! E_matrix_corrected.pt matches recomputation
In [8]:
# ============================================================================
# Examples of People with Events (showing only first event per disease)
# ============================================================================
print("\n" + "="*80)
print("EXAMPLES OF PEOPLE WITH EVENTS")
print("="*80)
# Convert Y to numpy if needed
if torch.is_tensor(Y):
Y_np = Y.numpy()
else:
Y_np = Y
if torch.is_tensor(E_corrected):
E_corrected_np = E_corrected.numpy()
else:
E_corrected_np = E_corrected
# Find patients with at least one event
patients_with_events = []
for patient_idx in range(Y_np.shape[0]):
# Check if patient has any events (E < T-1 means event occurred)
if (E_corrected_np[patient_idx, :] < (T - 1)).any():
patients_with_events.append(patient_idx)
print(f"\nTotal patients with at least one event: {len(patients_with_events):,} / {Y_np.shape[0]:,}")
# Show examples
print(f"\n{'='*80}")
print("EXAMPLE PATIENTS WITH EVENTS")
print(f"{'='*80}")
# Load disease names if available (adjust path as needed)
try:
disease_names = pd.read_csv('path/to/disease_names.csv')['name'].values
except:
disease_names = [f"Disease {i}" for i in range(Y_np.shape[1])]
# Show first 10 patients with events
for example_idx, patient_idx in enumerate(patients_with_events[:10]):
print(f"\nPatient {patient_idx}:")
# Find diseases with events for this patient
event_diseases = []
for disease_idx in range(Y_np.shape[1]):
event_timepoint = E_corrected_np[patient_idx, disease_idx]
if event_timepoint < (T - 1): # Event occurred (not censored at max time)
# Find FIRST event in Y (not all events)
Y_events = np.where(Y_np[patient_idx, disease_idx, :] == 1)[0]
if len(Y_events) > 0:
first_event_timepoint = Y_events[0]
event_diseases.append((disease_idx, event_timepoint, first_event_timepoint))
print(f" Number of diseases with events: {len(event_diseases)}")
# Show first 5 diseases (or all if <= 5)
for disease_idx, event_timepoint, first_event_timepoint in event_diseases[:5]:
disease_name = disease_names[disease_idx] if disease_idx < len(disease_names) else f"Disease {disease_idx}"
age_at_event = 30 + event_timepoint
print(f" Disease {disease_idx} ({disease_name}):")
print(f" E[{patient_idx}, {disease_idx}] = {event_timepoint} (age {age_at_event})")
print(f" Y[{patient_idx}, {disease_idx}, :] has events at timepoints: [{first_event_timepoint}]")
# Verify first event matches E
first_event_match = (first_event_timepoint == event_timepoint)
print(f" ✓ First event in Y matches E: {first_event_match}")
if not first_event_match:
print(f" ⚠️ WARNING: First event at timepoint {first_event_timepoint}, but E = {event_timepoint}")
if len(event_diseases) > 5:
print(f" ... and {len(event_diseases) - 5} more diseases with events")
================================================================================
EXAMPLES OF PEOPLE WITH EVENTS
================================================================================
Total patients with at least one event: 407,137 / 407,878
================================================================================
EXAMPLE PATIENTS WITH EVENTS
================================================================================
Patient 0:
Number of diseases with events: 2
Disease 161 (Disease 161):
E[0, 161] = 44 (age 74)
Y[0, 161, :] has events at timepoints: [44]
✓ First event in Y matches E: True
Disease 326 (Disease 326):
E[0, 326] = 44 (age 74)
Y[0, 326, :] has events at timepoints: [44]
✓ First event in Y matches E: True
Patient 1:
Number of diseases with events: 2
Disease 150 (Disease 150):
E[1, 150] = 22 (age 52)
Y[1, 150, :] has events at timepoints: [22]
✓ First event in Y matches E: True
Disease 157 (Disease 157):
E[1, 157] = 23 (age 53)
Y[1, 157, :] has events at timepoints: [23]
✓ First event in Y matches E: True
Patient 2:
Number of diseases with events: 3
Disease 180 (Disease 180):
E[2, 180] = 40 (age 70)
Y[2, 180, :] has events at timepoints: [40]
✓ First event in Y matches E: True
Disease 195 (Disease 195):
E[2, 195] = 40 (age 70)
Y[2, 195, :] has events at timepoints: [40]
✓ First event in Y matches E: True
Disease 299 (Disease 299):
E[2, 299] = 42 (age 72)
Y[2, 299, :] has events at timepoints: [42]
✓ First event in Y matches E: True
Patient 3:
Number of diseases with events: 26
Disease 10 (Disease 10):
E[3, 10] = 46 (age 76)
Y[3, 10, :] has events at timepoints: [46]
✓ First event in Y matches E: True
Disease 11 (Disease 11):
E[3, 11] = 46 (age 76)
Y[3, 11, :] has events at timepoints: [46]
✓ First event in Y matches E: True
Disease 21 (Disease 21):
E[3, 21] = 40 (age 70)
Y[3, 21, :] has events at timepoints: [40]
✓ First event in Y matches E: True
Disease 33 (Disease 33):
E[3, 33] = 46 (age 76)
Y[3, 33, :] has events at timepoints: [46]
✓ First event in Y matches E: True
Disease 55 (Disease 55):
E[3, 55] = 46 (age 76)
Y[3, 55, :] has events at timepoints: [46]
✓ First event in Y matches E: True
... and 21 more diseases with events
Patient 4:
Number of diseases with events: 24
Disease 37 (Disease 37):
E[4, 37] = 13 (age 43)
Y[4, 37, :] has events at timepoints: [13]
✓ First event in Y matches E: True
Disease 43 (Disease 43):
E[4, 43] = 37 (age 67)
Y[4, 43, :] has events at timepoints: [37]
✓ First event in Y matches E: True
Disease 45 (Disease 45):
E[4, 45] = 21 (age 51)
Y[4, 45, :] has events at timepoints: [21]
✓ First event in Y matches E: True
Disease 47 (Disease 47):
E[4, 47] = 37 (age 67)
Y[4, 47, :] has events at timepoints: [37]
✓ First event in Y matches E: True
Disease 60 (Disease 60):
E[4, 60] = 35 (age 65)
Y[4, 60, :] has events at timepoints: [35]
✓ First event in Y matches E: True
... and 19 more diseases with events
Patient 5:
Number of diseases with events: 1
Disease 52 (Disease 52):
E[5, 52] = 31 (age 61)
Y[5, 52, :] has events at timepoints: [31]
✓ First event in Y matches E: True
Patient 6:
Number of diseases with events: 32
Disease 14 (Disease 14):
E[6, 14] = 41 (age 71)
Y[6, 14, :] has events at timepoints: [41]
✓ First event in Y matches E: True
Disease 30 (Disease 30):
E[6, 30] = 43 (age 73)
Y[6, 30, :] has events at timepoints: [43]
✓ First event in Y matches E: True
Disease 45 (Disease 45):
E[6, 45] = 45 (age 75)
Y[6, 45, :] has events at timepoints: [45]
✓ First event in Y matches E: True
Disease 50 (Disease 50):
E[6, 50] = 43 (age 73)
Y[6, 50, :] has events at timepoints: [43]
✓ First event in Y matches E: True
Disease 52 (Disease 52):
E[6, 52] = 43 (age 73)
Y[6, 52, :] has events at timepoints: [43]
✓ First event in Y matches E: True
... and 27 more diseases with events
Patient 7:
Number of diseases with events: 10
Disease 1 (Disease 1):
E[7, 1] = 27 (age 57)
Y[7, 1, :] has events at timepoints: [27]
✓ First event in Y matches E: True
Disease 33 (Disease 33):
E[7, 33] = 19 (age 49)
Y[7, 33, :] has events at timepoints: [19]
✓ First event in Y matches E: True
Disease 162 (Disease 162):
E[7, 162] = 24 (age 54)
Y[7, 162, :] has events at timepoints: [24]
✓ First event in Y matches E: True
Disease 206 (Disease 206):
E[7, 206] = 27 (age 57)
Y[7, 206, :] has events at timepoints: [27]
✓ First event in Y matches E: True
Disease 208 (Disease 208):
E[7, 208] = 19 (age 49)
Y[7, 208, :] has events at timepoints: [19]
✓ First event in Y matches E: True
... and 5 more diseases with events
Patient 8:
Number of diseases with events: 4
Disease 261 (Disease 261):
E[8, 261] = 12 (age 42)
Y[8, 261, :] has events at timepoints: [12]
✓ First event in Y matches E: True
Disease 266 (Disease 266):
E[8, 266] = 11 (age 41)
Y[8, 266, :] has events at timepoints: [11]
✓ First event in Y matches E: True
Disease 273 (Disease 273):
E[8, 273] = 11 (age 41)
Y[8, 273, :] has events at timepoints: [11]
✓ First event in Y matches E: True
Disease 279 (Disease 279):
E[8, 279] = 11 (age 41)
Y[8, 279, :] has events at timepoints: [11]
✓ First event in Y matches E: True
Patient 9:
Number of diseases with events: 15
Disease 60 (Disease 60):
E[9, 60] = 46 (age 76)
Y[9, 60, :] has events at timepoints: [46]
✓ First event in Y matches E: True
Disease 61 (Disease 61):
E[9, 61] = 46 (age 76)
Y[9, 61, :] has events at timepoints: [46]
✓ First event in Y matches E: True
Disease 68 (Disease 68):
E[9, 68] = 46 (age 76)
Y[9, 68, :] has events at timepoints: [46]
✓ First event in Y matches E: True
Disease 112 (Disease 112):
E[9, 112] = 46 (age 76)
Y[9, 112, :] has events at timepoints: [46]
✓ First event in Y matches E: True
Disease 117 (Disease 117):
E[9, 117] = 46 (age 76)
Y[9, 117, :] has events at timepoints: [46]
✓ First event in Y matches E: True
... and 10 more diseases with events
Step 2: Verify prevalence_t_corrected.pt¶
In [3]:
# ============================================================================
# Step 2: Load saved prevalence_t_corrected and recompute from scratch
# ============================================================================
print("\n" + "="*80)
print("STEP 2: VERIFYING PREVALENCE_T_CORRECTED.PT")
print("="*80)
# Load saved corrected prevalence
prevalence_t_saved = torch.load(str(data_dir / 'prevalence_t_corrected.pt'), weights_only=False)
if torch.is_tensor(prevalence_t_saved):
prevalence_t_saved = prevalence_t_saved.numpy()
print(f"\n✓ Loaded saved prevalence_t_corrected: {prevalence_t_saved.shape}")
# Load Y and E_corrected to recompute
Y = torch.load(str(data_dir / 'Y_tensor.pt'), weights_only=False)
E_corrected = torch.load(str(data_dir / 'E_matrix_corrected.pt'), weights_only=False)
censor_df = pd.read_csv(str(data_dir / 'censor_info.csv'))
enrollment_ages = censor_df['age'].values
print(f"✓ Loaded Y_tensor: {Y.shape}")
print(f"✓ Loaded E_matrix_corrected: {E_corrected.shape}")
print(f"✓ Loaded enrollment ages: {len(enrollment_ages)} patients")
# Recompute prevalence using the same function
def compute_smoothed_prevalence_at_risk(Y, E_corrected, enrollment_ages, window_size=5, smooth_on_logit=True):
"""
Compute smoothed prevalence with proper at-risk filtering.
Parameters:
-----------
Y : torch.Tensor (N × D × T)
E_corrected : torch.Tensor (N × D) - corrected event/censor times
enrollment_ages : np.ndarray (N,) - enrollment ages for each person
window_size : int - Gaussian smoothing window size
smooth_on_logit : bool - Smooth on logit scale
"""
if torch.is_tensor(Y):
Y = Y.numpy()
if torch.is_tensor(E_corrected):
E_corrected = E_corrected.numpy()
N, D, T = Y.shape
prevalence_t = np.zeros((D, T))
# Convert timepoints to ages (assuming timepoint 0 = age 30)
timepoint_ages = np.arange(T) + 30
print(f"\n Computing prevalence for {D} diseases, {T} timepoints...")
# Convert E_corrected to numpy if needed
if torch.is_tensor(E_corrected):
E_corrected_np = E_corrected.numpy()
else:
E_corrected_np = E_corrected
for d in range(D):
if d % 50 == 0:
print(f" Processing disease {d}/{D}...")
for t in range(T):
age_t = timepoint_ages[t]
# Only include people who are still at risk at timepoint t
at_risk_mask = (E_corrected_np[:, d] >= t)
if at_risk_mask.sum() > 0:
if torch.is_tensor(Y):
prevalence_t[d, t] = Y[at_risk_mask, d, t].numpy().mean()
else:
prevalence_t[d, t] = Y[at_risk_mask, d, t].mean()
else:
prevalence_t[d, t] = np.nan
# Smooth as before
if smooth_on_logit:
epsilon = 1e-8
# Handle NaN values
valid_mask = ~np.isnan(prevalence_t[d, :])
if valid_mask.sum() > 0:
logit_prev = np.full(T, np.nan)
logit_prev[valid_mask] = np.log(
(prevalence_t[d, valid_mask] + epsilon) /
(1 - prevalence_t[d, valid_mask] + epsilon)
)
# Smooth only valid values
smoothed_logit = gaussian_filter1d(
np.nan_to_num(logit_prev, nan=0),
sigma=window_size
)
# Restore NaN where original was NaN
smoothed_logit[~valid_mask] = np.nan
prevalence_t[d, :] = 1 / (1 + np.exp(-smoothed_logit))
else:
prevalence_t[d, :] = gaussian_filter1d(
np.nan_to_num(prevalence_t[d, :], nan=0),
sigma=window_size
)
return prevalence_t
# Recompute
print("\nRecomputing prevalence_t_corrected...")
prevalence_t_recomputed = compute_smoothed_prevalence_at_risk(
Y=Y,
E_corrected=E_corrected,
enrollment_ages=enrollment_ages,
window_size=5,
smooth_on_logit=True
)
print(f"\n✓ Recomputed prevalence_t_corrected: {prevalence_t_recomputed.shape}")
================================================================================
STEP 2: VERIFYING PREVALENCE_T_CORRECTED.PT
================================================================================
✓ Loaded saved prevalence_t_corrected: (348, 52)
✓ Loaded Y_tensor: torch.Size([407878, 348, 52])
✓ Loaded E_matrix_corrected: torch.Size([407878, 348])
✓ Loaded enrollment ages: 407878 patients
Recomputing prevalence_t_corrected...
Computing prevalence for 348 diseases, 52 timepoints...
Processing disease 0/348...
Processing disease 50/348...
Processing disease 100/348...
Processing disease 150/348...
Processing disease 200/348...
Processing disease 250/348...
Processing disease 300/348...
✓ Recomputed prevalence_t_corrected: (348, 52)
In [4]:
# Compare recomputed prevalence with saved version
print(f"\n{'='*80}")
print("COMPARISON")
print(f"{'='*80}")
# Handle NaN values in comparison
valid_mask = ~(np.isnan(prevalence_t_saved) | np.isnan(prevalence_t_recomputed))
if valid_mask.sum() == 0:
print("\n⚠️ No valid values to compare (all NaN)")
else:
# Compare only valid values
diff = np.abs(prevalence_t_saved - prevalence_t_recomputed)
max_diff = np.nanmax(diff)
mean_diff = np.nanmean(diff[valid_mask])
# Check if they're close (allowing for small numerical differences from smoothing)
close_match = np.allclose(prevalence_t_saved, prevalence_t_recomputed,
equal_nan=True, rtol=1e-5, atol=1e-6)
print(f"\nComparison statistics:")
print(f" Valid elements: {valid_mask.sum()} / {prevalence_t_saved.size}")
print(f" Max absolute difference: {max_diff:.8f}")
print(f" Mean absolute difference: {mean_diff:.8f}")
if close_match:
print(f"\n✅ CLOSE MATCH! prevalence_t_corrected.pt matches recomputation")
print(f" (Small differences expected due to numerical precision in smoothing)")
else:
print(f"\n⚠️ SIGNIFICANT DIFFERENCES DETECTED!")
print(f" Max diff: {max_diff:.8f}, Mean diff: {mean_diff:.8f}")
# Show some examples of differences
large_diff_mask = diff > 0.01
if large_diff_mask.sum() > 0:
print(f"\n Found {large_diff_mask.sum()} elements with difference > 0.01")
large_diff_indices = np.where(large_diff_mask)
print(f" Example differences:")
for i in range(min(5, len(large_diff_indices[0]))):
d_idx = large_diff_indices[0][i]
t_idx = large_diff_indices[1][i]
print(f" Disease {d_idx}, Timepoint {t_idx}: "
f"Saved={prevalence_t_saved[d_idx, t_idx]:.6f}, "
f"Recomputed={prevalence_t_recomputed[d_idx, t_idx]:.6f}, "
f"Diff={diff[d_idx, t_idx]:.6f}")
================================================================================ COMPARISON ================================================================================ Comparison statistics: Valid elements: 18096 / 18096 Max absolute difference: 0.00000000 Mean absolute difference: 0.00000000 ✅ CLOSE MATCH! prevalence_t_corrected.pt matches recomputation (Small differences expected due to numerical precision in smoothing)
Summary¶
In [5]:
# Summary
print(f"\n{'='*80}")
print("VERIFICATION SUMMARY")
print(f"{'='*80}")
print("\n✓ Step 1: E_matrix_corrected.pt verification complete")
print("✓ Step 2: prevalence_t_corrected.pt verification complete")
print("\n" + "="*80)
print("If both verifications passed, the saved files match the workflow!")
print("="*80)
================================================================================ VERIFICATION SUMMARY ================================================================================ ✓ Step 1: E_matrix_corrected.pt verification complete ✓ Step 2: prevalence_t_corrected.pt verification complete ================================================================================ If both verifications passed, the saved files match the workflow! ================================================================================
In [ ]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
# Load data
cov_df = pd.read_csv("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/baselinagefamh_withpcs.csv")[:10000]
# Load Y from checkpoint
checkpoint = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/enrollment_model_W0.0001_batch_0_10000.pt", weights_only=False)
Y = checkpoint['Y']
del checkpoint
# Load corrected E matrix for proper at-risk filtering
E_corrected = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_matrix_corrected.pt", weights_only=False)
# Subset to match the batch size (first 10000 patients)
E_corrected = E_corrected[:10000]
print(f"✓ Loaded Y: {Y.shape}")
print(f"✓ Loaded E_corrected: {E_corrected.shape}")
print(f"✓ Loaded cov_df: {len(cov_df)} patients")
In [1]:
# ============================================================================
# Full Dataset Calibration: 400k patients using pre-computed pi
# ============================================================================
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
# Load pre-computed pi (full 400k dataset)
#pi_full = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox/censor_e_batchrun_vectorized/pi_fullmode_400k.pt", map_location='cpu', weights_only=False)[:50000]
pi_full = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_correctedE_vectorized/pi_enroll_fixedphi_sex_FULL.pt",
)[:50000]
print(f"✓ Loaded pre-computed pi: {pi_full.shape}")
# Load Y (full dataset)
Y_full = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/Y_tensor.pt",
map_location='cpu', weights_only=False)[:50000]
print(f"✓ Loaded Y: {Y_full.shape}")
# Load corrected E matrix (full dataset)
E_corrected_full = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/E_matrix_corrected.pt",
map_location='cpu', weights_only=False)[:50000]
print(f"✓ Loaded E_corrected: {E_corrected_full.shape}")
# Load covariates (full dataset)
cov_df_full = pd.read_csv("/Users/sarahurbut/Library/CloudStorage/Dropbox-Personal/data_for_running/baselinagefamh_withpcs.csv")[:50000]
print(f"✓ Loaded cov_df: {len(cov_df_full)} patients")
# Convert to numpy
pi_np = pi_full.detach().numpy()
Y_np = Y_full.detach().numpy()
if torch.is_tensor(E_corrected_full):
E_corrected_np = E_corrected_full.detach().numpy()
else:
E_corrected_np = E_corrected_full
N, D, T = Y_np.shape
print(f"\nDataset dimensions: {N} patients × {D} diseases × {T} timepoints")
# Create at_risk mask using corrected E matrix
print("\nCreating at-risk mask...")
at_risk = np.zeros((N, D, T), dtype=bool)
for n in range(N):
if n % 50000 == 0:
print(f" Processing patient {n}/{N}...")
for d in range(D):
# Patient is at risk at timepoint t if event/censor time >= t
at_risk[n, d, :] = (E_corrected_np[n, d] >= np.arange(T))
print("✓ At-risk mask created")
# Collect all predictions and observations (at-risk only)
print("\nCollecting predictions and observations...")
all_pred = []
all_obs = []
for t in range(T):
if t % 10 == 0:
print(f" Processing timepoint {t}/{T}...")
mask_t = at_risk[:,:,t]
if mask_t.sum() > 0:
all_pred.extend(pi_np[:,:,t][mask_t])
all_obs.extend(Y_np[:,:,t][mask_t])
all_pred = np.array(all_pred)
all_obs = np.array(all_obs)
print(f"\n✓ Collected {len(all_pred):,} predictions/observations")
print(f" Mean predicted: {all_pred.mean():.2e}")
print(f" Mean observed: {all_obs.mean():.2e}")
# Create calibration plot
print("\nCreating calibration plot...")
fig, ax = plt.subplots(figsize=(12, 10), dpi=300)
# Create bins in log space
n_bins = 50
min_bin_count = 10000 # Higher threshold for full dataset
use_log_scale = True
if use_log_scale:
bin_edges = np.logspace(np.log10(max(1e-7, min(all_pred))),
np.log10(max(all_pred)),
n_bins + 1)
else:
bin_edges = np.linspace(min(all_pred), max(all_pred), n_bins + 1)
# Calculate statistics for each bin
bin_means = []
obs_means = []
counts = []
for i in range(n_bins):
mask = (all_pred >= bin_edges[i]) & (all_pred < bin_edges[i + 1])
if np.sum(mask) >= min_bin_count:
bin_means.append(np.mean(all_pred[mask]))
obs_means.append(np.mean(all_obs[mask]))
counts.append(np.sum(mask))
# Plot
if use_log_scale:
ax.plot([1e-7, 1], [1e-7, 1], '--', color='gray', alpha=0.5, label='Perfect calibration', linewidth=2)
ax.set_xscale('log')
ax.set_yscale('log')
else:
ax.plot([0, max(all_pred)], [0, max(all_pred)], '--', color='gray', alpha=0.5, label='Perfect calibration', linewidth=2)
ax.plot(bin_means, obs_means, 'o-', color='#1f77b4',
markersize=10, linewidth=2.5, label='Observed rates', alpha=0.8)
# Add counts as annotations
for i, (x, y, c) in enumerate(zip(bin_means, obs_means, counts)):
ax.annotate(f'n={c:,}', (x, y), xytext=(0, 12),
textcoords='offset points', ha='center', fontsize=9)
# Add summary statistics
mse = np.mean((np.array(bin_means) - np.array(obs_means))**2)
mean_pred = np.mean(all_pred)
mean_obs = np.mean(all_obs)
stats_text = f'MSE: {mse:.2e}\n'
stats_text += f'Mean Predicted: {mean_pred:.2e}\n'
stats_text += f'Mean Observed: {mean_obs:.2e}\n'
stats_text += f'N total: {sum(counts):,}'
ax.text(0.05, 0.95, stats_text,
transform=ax.transAxes,
verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.9),
fontsize=11)
ax.grid(True, which='both', linestyle='--', alpha=0.3)
ax.set_xlabel('Predicted Event Rate', fontsize=14, fontweight='bold')
ax.set_ylabel('Observed Event Rate', fontsize=14, fontweight='bold')
ax.set_title('Calibration Across All Follow-up (At-Risk Only)\nFull Dataset (400k patients)',
fontsize=16, fontweight='bold', pad=20)
ax.legend(loc='lower right', fontsize=12)
plt.tight_layout()
# Save plot
save_path = "/Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/reviewer_responses/notebooks/R3/calibration_plots_full_400k.pdf"
plt.savefig(save_path, format='pdf', dpi=300, bbox_inches='tight')
print(f"\n✓ Saved calibration plot to: {save_path}")
plt.show()
/var/folders/fl/ng5crz0x0fnb6c6x8dk7tfth0000gn/T/ipykernel_95235/2192749160.py:14: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
pi_full = torch.load("/Users/sarahurbut/Library/CloudStorage/Dropbox/enrollment_predictions_fixedphi_correctedE_vectorized/pi_enroll_fixedphi_sex_FULL.pt",
✓ Loaded pre-computed pi: torch.Size([50000, 348, 52]) ✓ Loaded Y: torch.Size([50000, 348, 52]) ✓ Loaded E_corrected: torch.Size([50000, 348]) ✓ Loaded cov_df: 50000 patients Dataset dimensions: 50000 patients × 348 diseases × 52 timepoints Creating at-risk mask... Processing patient 0/50000... ✓ At-risk mask created Collecting predictions and observations... Processing timepoint 0/52... Processing timepoint 10/52... Processing timepoint 20/52... Processing timepoint 30/52... Processing timepoint 40/52... Processing timepoint 50/52... ✓ Collected 738,708,915 predictions/observations Mean predicted: 4.14e-04 Mean observed: 5.45e-04 Creating calibration plot... ✓ Saved calibration plot to: /Users/sarahurbut/aladynoulli2/pyScripts/dec_6_revision/new_notebooks/reviewer_responses/notebooks/R3/calibration_plots_full_400k.pdf