Preprocessing Pipeline: Create Initialization Files¶

This notebook shows how to create all the initialization files needed to run the Aladynoulli model:

  1. Prevalence (prevalence_t) - Disease prevalence over time (smoothed)
  2. Clusters (initial_clusters_400k.pt) - Disease-to-signature assignments (via spectral clustering)
  3. Psi (initial_psi_400k.pt) - Initial signature-disease association parameters (based on clusters)
  4. Reference Trajectories (reference_trajectories.pt) - Population-level signature trajectories (smoothed)

Key Point: Clusters and Psi are created automatically when you initialize the model, and create_initial_clusters_and_psi() from preprocessing_utils.py builds the same clusters without initializing the model. You just need to:

  • Compute prevalence_t from Y
  • Create clusters and psi (Step 2)
  • Create reference trajectories from clusters

Based on with_bigdata.ipynb and minimalreprobigforinit.ipynb

1. Setup and Imports¶

✓ Imports complete

Workflow Overview¶

Step 1: Compute prevalence_t (smoothed disease prevalence over time)
Step 2: Create clusters and psi (the model creates the same clusters automatically at initialization)
Step 3: Create reference trajectories from clusters

That's it! No manual cluster or psi construction is needed.

Step 1: Compute Prevalence (prevalence_t)¶

Compute smoothed disease prevalence over time. This is used by the model for initialization.

2. Load Data¶

Load Y, E, G tensors and any other required data.

✓ Loaded data:
  Y shape: torch.Size([407878, 348, 52]) (patients × diseases × timepoints)
  E shape: torch.Size([407878, 348]) (patients × diseases)
  G shape: torch.Size([407878, 36]) (patients × genetic variants)
  K (signatures): 20

Calculate disease prevalence over time with Gaussian smoothing.

Note on smoothing scale:

  • Logit scale (default): Better for rare events, preserves relative differences. Recommended for most cases.
  • Probability scale: More intuitive, but can be problematic for rare events.

The function compute_smoothed_prevalence_at_risk() is imported from preprocessing_utils.py - no need to define it here!
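The idea behind logit-scale smoothing can be sketched as follows. This is a minimal NumPy illustration, not the code in preprocessing_utils.py: the function name smoothed_prevalence_at_risk, the interpretation of E as the last at-risk timepoint index, and the sigma and eps defaults are all assumptions made for the example.

```python
import numpy as np

def smoothed_prevalence_at_risk(Y, E, sigma=1.0, eps=1e-8):
    """Hypothetical sketch. Y: (N, D, T) binary events; E: (N, D) last at-risk time index."""
    N, D, T = Y.shape
    t_idx = np.arange(T)
    # at_risk[n, d, t] is True while patient n is still observed for disease d (assumed meaning of E).
    # On the full dataset this mask would be too large; a per-disease loop would be needed instead.
    at_risk = t_idx[None, None, :] <= E[:, :, None]
    events = (Y * at_risk).sum(axis=0)                 # (D, T) events among at-risk patients
    n_at_risk = np.maximum(at_risk.sum(axis=0), 1)     # (D, T), guarded against division by zero
    raw = np.clip(events / n_at_risk, eps, 1 - eps)
    logit = np.log(raw / (1 - raw))
    # Normalized Gaussian kernel, applied along the time axis with edge padding
    radius = max(1, int(4 * sigma))
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    kernel /= kernel.sum()
    padded = np.pad(logit, ((0, 0), (radius, radius)), mode="edge")
    smoothed = np.stack([np.convolve(row, kernel, mode="valid") for row in padded])
    # Map back to the probability scale
    return 1.0 / (1.0 + np.exp(-smoothed))
```

Smoothing on the logit scale keeps the result strictly inside (0, 1) and treats a change from 0.001 to 0.002 as comparable to a change from 0.01 to 0.02, which is why it tends to behave better for rare diseases.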

  Computing prevalence for 348 diseases, 52 timepoints...
    Processing disease 0/348...
    Processing disease 50/348...
    Processing disease 100/348...
    Processing disease 150/348...
    Processing disease 200/348...
    Processing disease 250/348...
    Processing disease 300/348...
✓ Computed prevalence:
  Shape: (348, 52) (diseases × timepoints)
  Range: [0.0000, 0.0376]
  Method: Smoothed on logit scale (better for rare events)
  Note: To recreate exact same psi, must use same prevalence_t as original
✓ Loaded disease names
✓ Visualized smoothed prevalence for 6 example diseases

Step 2: Create Clusters and Psi¶

Create disease clusters and initial psi parameters using spectral clustering.

Key: The function create_initial_clusters_and_psi() is imported from preprocessing_utils.py - no model initialization needed!

What happens:¶

  1. Compute disease similarity matrix from Y (correlation across patients of the logit-transformed, time-averaged Y)
  2. Spectral clustering assigns diseases to signatures (clusters) using random_state=42
  3. Psi is initialized: psi[k, d] = 1.0 + noise if disease d is in cluster k, else psi[k, d] = -2.0 + noise

Benefits: No model initialization needed - much faster!
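The psi rule in item 3 above can be sketched as below. This is an illustration only: the function name init_psi, the noise scale noise_sd, and the seed are hypothetical, and the actual noise distribution used by create_initial_clusters_and_psi() is not stated in this notebook.

```python
import numpy as np

def init_psi(clusters, K, in_value=1.0, out_value=-2.0, noise_sd=0.01, seed=0):
    """Hypothetical sketch: psi[k, d] ~ in_value if disease d is in cluster k, else out_value."""
    rng = np.random.default_rng(seed)
    clusters = np.asarray(clusters)
    D = clusters.shape[0]
    psi = np.full((K, D), out_value, dtype=float)
    # For each disease d, raise the entry of its own cluster
    psi[clusters, np.arange(D)] = in_value
    # Small Gaussian noise (assumed form) breaks exact ties between entries
    return psi + noise_sd * rng.standard_normal((K, D))
```

With noise this small, the largest entry in each column of psi recovers the cluster label of that disease.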

Step 3: Create Reference Trajectories¶

Create population-level signature reference trajectories using LOWESS smoothing on logit scale.

What this does:

  1. For each signature, compute proportion of diseases in that signature over time
  2. Smooth using LOWESS on logit scale
  3. Create healthy reference trajectory

Uses: Clusters from Step 2
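A rough sketch of steps 1 and 2 follows. It rests on several assumptions: the "proportion" is read as the mean of Y over patients and over the diseases assigned to the signature, lowess_1d is a small hand-written LOWESS (tricube weights, local linear fit, no robustness iterations) rather than whatever library routine the notebook calls, and frac and eps are placeholder values. The healthy reference is not reproduced because its construction is not described here.

```python
import numpy as np

def lowess_1d(y, frac=0.3):
    """Minimal LOWESS on an evenly spaced grid: tricube-weighted local linear fits."""
    y = np.asarray(y, dtype=float)
    n = y.shape[0]
    x = np.arange(n, dtype=float)
    k = min(n, max(4, int(np.ceil(frac * n))))   # number of neighbours per local fit
    A = np.stack([np.ones(n), x], axis=1)
    out = np.empty(n)
    for i in range(n):
        d = np.abs(x - x[i])
        h = np.sort(d)[k - 1]                    # bandwidth: distance to the k-th nearest point
        w = np.clip(1 - (d / h) ** 3, 0, None) ** 3
        sw = np.sqrt(w)
        beta, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
        out[i] = beta[0] + beta[1] * x[i]
    return out

def signature_references(Y, clusters, K, frac=0.3, eps=1e-8):
    """Hypothetical sketch: one smoothed logit-scale trajectory per signature."""
    clusters = np.asarray(clusters)
    T = Y.shape[2]
    refs = np.zeros((K, T))
    for k in range(K):
        members = np.where(clusters == k)[0]
        if members.size == 0:
            continue                             # empty cluster: leave the row at zero
        prop = np.clip(Y[:, members, :].mean(axis=(0, 1)), eps, 1 - eps)
        refs[k] = lowess_1d(np.log(prop / (1 - prop)), frac)
    return refs
```

On the toy scale this returns an array of shape K × T, matching the signature_refs shape reported in the output below.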

Verification: Clusters Match Model Initialization¶

Confirmation: The clusters from create_initial_clusters_and_psi() will match exactly what the model produces because:

  1. Same Y_avg computation: torch.mean(Y, dim=2), clamp, logit transform
  2. Same correlation matrix: torch.corrcoef(Y_avg.T) with NaN handling
  3. Same similarity matrix: (Y_corr + 1) / 2
  4. Same spectral clustering: random_state=42, n_init=10, assign_labels='kmeans'
  5. Deterministic: Given the same inputs and random_state, spectral clustering produces identical cluster assignments

Note: Psi values may differ slightly due to random number generation, but clusters will be identical.
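Items 1 to 3 of the list above can be written out in NumPy as a sketch. The function name disease_similarity, the clamp value eps, and the choice to replace NaN correlations with 0 are assumptions; the notebook only says that NaNs are handled.

```python
import numpy as np

def disease_similarity(Y, eps=1e-8):
    """Hypothetical sketch. Y: (N, D, T) binary events -> (D, D) similarity in [0, 1]."""
    Y_avg = Y.mean(axis=2)                        # (N, D): time-averaged event indicator
    Y_avg = np.clip(Y_avg, eps, 1 - eps)          # clamp before the logit
    Y_logit = np.log(Y_avg / (1 - Y_avg))
    # Rows of Y_logit.T are diseases, so corrcoef yields a D x D matrix.
    # Diseases with zero variance produce NaN, hence the errstate guard.
    with np.errstate(invalid="ignore", divide="ignore"):
        corr = np.corrcoef(Y_logit.T)
    corr = np.nan_to_num(corr, nan=0.0)           # assumed NaN handling
    return (corr + 1) / 2                         # map correlations from [-1, 1] to [0, 1]
```

A matrix of this form could then be handed to scikit-learn's SpectralClustering, presumably as a precomputed affinity, together with the random_state=42, n_init=10, and assign_labels='kmeans' settings listed in item 4.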
✓ Prevalence shape: (348, 52)

Cluster Sizes:
Cluster 0: 16 diseases
Cluster 1: 20 diseases
Cluster 2: 8 diseases
Cluster 3: 18 diseases
Cluster 4: 9 diseases
Cluster 5: 8 diseases
Cluster 6: 24 diseases
Cluster 7: 7 diseases
Cluster 8: 28 diseases
Cluster 9: 12 diseases
Cluster 10: 11 diseases
Cluster 11: 83 diseases
Cluster 12: 12 diseases
Cluster 13: 5 diseases
Cluster 14: 16 diseases
Cluster 15: 5 diseases
Cluster 16: 21 diseases
Cluster 17: 10 diseases
Cluster 18: 7 diseases
Cluster 19: 28 diseases
✓ Clusters shape: (348,)
✓ Psi shape: torch.Size([20, 348])
✓ Signature refs shape: torch.Size([20, 52])
✓ Healthy ref shape: torch.Size([52])

Summary: Files Created¶

After running Steps 1-3, you have created:

  1. prevalence_t (torch.Tensor, shape: D × T) - Smoothed disease prevalence over time
  2. clusters (numpy array, shape: D) - Disease-to-signature cluster assignments
  3. psi (torch.Tensor, shape: K × D, or (K+1) × D when the healthy state is included) - Signature-disease association parameters
  4. signature_refs (torch.Tensor, shape: K × T) - Reference trajectories for each signature (logit scale)
  5. healthy_ref (torch.Tensor, shape: T) - Healthy reference trajectory (logit scale)

Save these files:

  • initial_clusters_400k.pt - clusters
  • initial_psi_400k.pt - psi
  • reference_trajectories.pt - signature_refs and healthy_ref
  • model_essentials.pt - prevalence_t, disease_names, and metadata
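Saving and reloading could look like the sketch below, shown on toy tensors written to a temporary directory. The dictionary keys inside reference_trajectories.pt and the conversion of clusters to a tensor before saving are assumptions made for this example; the layout of the original files is not shown in this notebook.

```python
import os
import tempfile

import numpy as np
import torch

# Toy stand-ins for the real outputs of Steps 1-3
K, D, T = 3, 5, 4
clusters = np.array([0, 1, 2, 0, 1])
psi = torch.randn(K, D)
signature_refs = torch.zeros(K, T)
healthy_ref = torch.zeros(T)

out_dir = tempfile.mkdtemp()
clusters_path = os.path.join(out_dir, "initial_clusters_400k.pt")
psi_path = os.path.join(out_dir, "initial_psi_400k.pt")
refs_path = os.path.join(out_dir, "reference_trajectories.pt")

# Storing clusters as a tensor (a choice made here) keeps the file loadable with weights_only=True
torch.save(torch.from_numpy(clusters), clusters_path)
torch.save(psi, psi_path)
torch.save({"signature_refs": signature_refs, "healthy_ref": healthy_ref}, refs_path)

# weights_only=True restricts unpickling to tensors and plain containers
loaded_clusters = torch.load(clusters_path, weights_only=True)
loaded_psi = torch.load(psi_path, weights_only=True)
loaded_refs = torch.load(refs_path, weights_only=True)
```

Passing weights_only explicitly also makes the loading behavior independent of the PyTorch version's default.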

Now you can rerun the model in either predict or batch mode with these items!¶


Calculating gamma for k=0:
Number of diseases in cluster: 16
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.9623, -13.8155])
Base value centered (first 5): tensor([-0.1935, -0.1935, -0.1935,  1.6597, -0.1935])
Base value centered mean: -2.2248955247050617e-06
Gamma init for k=0 (first 5): tensor([-7.6671e-06,  2.4498e-04, -2.5758e-03,  3.7284e-03,  3.1849e-02])

Calculating gamma for k=1:
Number of diseases in cluster: 20
Base value (first 5): tensor([-13.3213, -13.8155, -13.8155, -13.8155, -10.8504])
Base value centered (first 5): tensor([-0.0469, -0.5411, -0.5411, -0.5411,  2.4240])
Base value centered mean: 1.2302360801186296e-06
Gamma init for k=1 (first 5): tensor([-0.0025, -0.0011, -0.0042,  0.0276,  0.0060])

Calculating gamma for k=2:
Number of diseases in cluster: 8
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1586, -0.1586, -0.1586, -0.1586, -0.1586])
Base value centered mean: 1.22929847634623e-07
Gamma init for k=2 (first 5): tensor([ 0.0008,  0.0023,  0.0010, -0.0006,  0.0002])

Calculating gamma for k=3:
Number of diseases in cluster: 18
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155,  -9.9719, -12.1682])
Base value centered (first 5): tensor([-0.4029, -0.4029, -0.4029,  3.4407,  1.2443])
Base value centered mean: -7.372049708465056e-07
Gamma init for k=3 (first 5): tensor([-0.0017,  0.0007, -0.0002,  0.0071,  0.0045])

Calculating gamma for k=4:
Number of diseases in cluster: 9
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1585, -0.1585, -0.1585, -0.1585, -0.1585])
Base value centered mean: -8.551031527304076e-08
Gamma init for k=4 (first 5): tensor([-0.0034,  0.0035, -0.0015,  0.0030, -0.0003])

Calculating gamma for k=5:
Number of diseases in cluster: 8
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1156, -0.1156, -0.1156, -0.1156, -0.1156])
Base value centered mean: -1.4170835811455618e-06
Gamma init for k=5 (first 5): tensor([-1.5859e-03,  8.5422e-05,  2.4592e-03,  1.5822e-03,  6.4423e-03])

Calculating gamma for k=6:
Number of diseases in cluster: 24
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.4037])
Base value centered (first 5): tensor([-0.1708, -0.1708, -0.1708, -0.1708,  0.2410])
Base value centered mean: -1.818097302930255e-06
Gamma init for k=6 (first 5): tensor([-0.0017,  0.0011, -0.0011,  0.0015,  0.0022])

Calculating gamma for k=7:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7188, -0.7188, -0.7188, -0.7188, -0.7188])
Base value centered mean: -9.069723887478176e-07
Gamma init for k=7 (first 5): tensor([-0.0055, -0.0013,  0.0016,  0.0026,  0.0089])

Calculating gamma for k=8:
Number of diseases in cluster: 28
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -12.0506, -13.1095])
Base value centered (first 5): tensor([-0.2582, -0.2582, -0.2582,  1.5068,  0.4478])
Base value centered mean: 9.576537962630027e-08
Gamma init for k=8 (first 5): tensor([-0.0018, -0.0001,  0.0034,  0.0056,  0.0046])

Calculating gamma for k=9:
Number of diseases in cluster: 12
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.5209, -12.9919])
Base value centered (first 5): tensor([-0.1712, -0.1712, -0.1712,  3.1234,  0.6524])
Base value centered mean: 1.430266024726734e-06
Gamma init for k=9 (first 5): tensor([ 0.0002, -0.0007,  0.0005,  0.0003,  0.0009])

Calculating gamma for k=10:
Number of diseases in cluster: 11
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.2733, -0.2733, -0.2733, -0.2733, -0.2733])
Base value centered mean: 6.260780764932861e-07
Gamma init for k=10 (first 5): tensor([-4.3365e-04,  9.8768e-03, -6.8178e-03,  7.1836e-05, -1.3331e-03])

Calculating gamma for k=11:
Number of diseases in cluster: 83
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6964, -13.5773])
Base value centered (first 5): tensor([-0.1039, -0.1039, -0.1039,  0.0152,  0.1343])
Base value centered mean: -1.198039967675868e-06
Gamma init for k=11 (first 5): tensor([-8.3130e-04,  8.9312e-05, -4.3690e-05,  1.9273e-03,  1.4553e-03])

Calculating gamma for k=12:
Number of diseases in cluster: 12
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -12.9919])
Base value centered (first 5): tensor([-0.2032, -0.2032, -0.2032, -0.2032,  0.6204])
Base value centered mean: 1.2567575140565168e-06
Gamma init for k=12 (first 5): tensor([-0.0024, -0.0004, -0.0011,  0.0040, -0.0005])

Calculating gamma for k=13:
Number of diseases in cluster: 5
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -11.8388])
Base value centered (first 5): tensor([-0.2856, -0.2856, -0.2856, -0.2856,  1.6912])
Base value centered mean: -2.600082325443509e-06
Gamma init for k=13 (first 5): tensor([ 0.0004, -0.0001, -0.0074, -0.0025,  0.0003])

Calculating gamma for k=14:
Number of diseases in cluster: 16
Base value (first 5): tensor([-13.8155, -13.8155, -12.5800, -11.9623, -12.5800])
Base value centered (first 5): tensor([-0.4280, -0.4280,  0.8074,  1.4252,  0.8074])
Base value centered mean: -2.0309892079239944e-06
Gamma init for k=14 (first 5): tensor([-0.0070,  0.0013, -0.0025,  0.0160,  0.0009])

Calculating gamma for k=15:
Number of diseases in cluster: 5
Base value (first 5): tensor([-13.8155,  -9.8620, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1162,  3.8372, -0.1162, -0.1162, -0.1162])
Base value centered mean: -2.9337811611185316e-06
Gamma init for k=15 (first 5): tensor([-0.0014,  0.0011, -0.0015,  0.0206, -0.0005])

Calculating gamma for k=16:
Number of diseases in cluster: 21
Base value (first 5): tensor([-13.3449, -13.8155, -13.3449, -13.3449, -12.4036])
Base value centered (first 5): tensor([ 0.1508, -0.3198,  0.1508,  0.1508,  1.0921])
Base value centered mean: -2.0833728342495306e-07
Gamma init for k=16 (first 5): tensor([-0.0029, -0.0009, -0.0005,  0.0095,  0.0007])

Calculating gamma for k=17:
Number of diseases in cluster: 10
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.8388, -12.8271])
Base value centered (first 5): tensor([-0.2673, -0.2673, -0.2673,  1.7094,  0.7211])
Base value centered mean: -1.5654569551770692e-06
Gamma init for k=17 (first 5): tensor([-0.0072,  0.0036, -0.0037,  0.0201,  0.0041])

Calculating gamma for k=18:
Number of diseases in cluster: 7
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.0942, -0.0942, -0.0942, -0.0942, -0.0942])
Base value centered mean: -9.127803082265018e-07
Gamma init for k=18 (first 5): tensor([-0.0011,  0.0009, -0.0009, -0.0003,  0.0008])

Calculating gamma for k=19:
Number of diseases in cluster: 28
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.4625])
Base value centered (first 5): tensor([-0.1242, -0.1242, -0.1242, -0.1242,  0.2288])
Base value centered mean: -5.96346296788397e-07
Gamma init for k=19 (first 5): tensor([ 6.4825e-03,  1.5252e-04, -1.1859e-04,  1.2878e-03,  3.7712e-05])
Initializing with 20 disease states + 1 healthy state
Initialization complete!

Calculating gamma for k=0:
Number of diseases in cluster: 16.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.9623, -13.8155])
Base value centered (first 5): tensor([-0.1935, -0.1935, -0.1935,  1.6597, -0.1935])
Base value centered mean: -2.2248955247050617e-06
Gamma init for k=0 (first 5): tensor([-7.6671e-06,  2.4498e-04, -2.5758e-03,  3.7284e-03,  3.1849e-02])

Calculating gamma for k=1:
Number of diseases in cluster: 21.0
Base value (first 5): tensor([-13.3449, -13.8155, -13.3449, -13.3449, -12.4036])
Base value centered (first 5): tensor([ 0.1508, -0.3198,  0.1508,  0.1508,  1.0921])
Base value centered mean: -2.0833728342495306e-07
Gamma init for k=1 (first 5): tensor([-0.0029, -0.0009, -0.0005,  0.0095,  0.0007])

Calculating gamma for k=2:
Number of diseases in cluster: 15.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.1566, -11.8388, -12.4977])
Base value centered (first 5): tensor([-0.3874, -0.3874,  0.2715,  1.5893,  0.9304])
Base value centered mean: 9.514063208371226e-07
Gamma init for k=2 (first 5): tensor([-0.0068,  0.0018, -0.0015,  0.0147,  0.0013])

Calculating gamma for k=3:
Number of diseases in cluster: 82.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.6950, -13.5744])
Base value centered (first 5): tensor([-0.1046, -0.1046, -0.1046,  0.0159,  0.1364])
Base value centered mean: -8.652273209008854e-08
Gamma init for k=3 (first 5): tensor([-8.6792e-04,  7.5334e-05, -4.9636e-05,  1.9657e-03,  1.4700e-03])

Calculating gamma for k=4:
Number of diseases in cluster: 5.0
Base value (first 5): tensor([-13.8155,  -9.8620, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1162,  3.8372, -0.1162, -0.1162, -0.1162])
Base value centered mean: -2.9337811611185316e-06
Gamma init for k=4 (first 5): tensor([-0.0014,  0.0011, -0.0015,  0.0206, -0.0005])

Calculating gamma for k=5:
Number of diseases in cluster: 7.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.7188, -0.7188, -0.7188, -0.7188, -0.7188])
Base value centered mean: -9.069723887478176e-07
Gamma init for k=5 (first 5): tensor([-0.0055, -0.0013,  0.0016,  0.0026,  0.0089])

Calculating gamma for k=6:
Number of diseases in cluster: 8.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1586, -0.1586, -0.1586, -0.1586, -0.1586])
Base value centered mean: 1.22929847634623e-07
Gamma init for k=6 (first 5): tensor([ 0.0008,  0.0023,  0.0010, -0.0006,  0.0002])

Calculating gamma for k=7:
Number of diseases in cluster: 22.0
Base value (first 5): tensor([-13.3663, -13.8155, -13.3663, -13.8155, -11.1200])
Base value centered (first 5): tensor([-0.1104, -0.5596, -0.1104, -0.5596,  2.1359])
Base value centered mean: 1.2441315675459919e-06
Gamma init for k=7 (first 5): tensor([-0.0037, -0.0014, -0.0044,  0.0273,  0.0056])

Calculating gamma for k=8:
Number of diseases in cluster: 28.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1171, -0.1171, -0.1171, -0.1171, -0.1171])
Base value centered mean: -1.825794470278197e-06
Gamma init for k=8 (first 5): tensor([ 6.5371e-03,  8.5117e-05, -7.7508e-05,  9.9897e-04,  9.3553e-05])

Calculating gamma for k=9:
Number of diseases in cluster: 12.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -12.9919])
Base value centered (first 5): tensor([-0.2032, -0.2032, -0.2032, -0.2032,  0.6204])
Base value centered mean: 1.2567575140565168e-06
Gamma init for k=9 (first 5): tensor([-0.0024, -0.0004, -0.0011,  0.0040, -0.0005])

Calculating gamma for k=10:
Number of diseases in cluster: 11.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.2733, -0.2733, -0.2733, -0.2733, -0.2733])
Base value centered mean: 6.260780764932861e-07
Gamma init for k=10 (first 5): tensor([-4.3365e-04,  9.8768e-03, -6.8178e-03,  7.1836e-05, -1.3331e-03])

Calculating gamma for k=11:
Number of diseases in cluster: 8.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1156, -0.1156, -0.1156, -0.1156, -0.1156])
Base value centered mean: -1.4170835811455618e-06
Gamma init for k=11 (first 5): tensor([-1.5859e-03,  8.5422e-05,  2.4592e-03,  1.5822e-03,  6.4423e-03])

Calculating gamma for k=12:
Number of diseases in cluster: 7.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.0942, -0.0942, -0.0942, -0.0942, -0.0942])
Base value centered mean: -9.127803082265018e-07
Gamma init for k=12 (first 5): tensor([-0.0011,  0.0009, -0.0009, -0.0003,  0.0008])

Calculating gamma for k=13:
Number of diseases in cluster: 13.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.7744, -12.2949])
Base value centered (first 5): tensor([-0.1767, -0.1767, -0.1767,  2.8644,  1.3439])
Base value centered mean: -5.641829261548992e-07
Gamma init for k=13 (first 5): tensor([ 0.0002, -0.0004,  0.0004,  0.0008,  0.0007])

Calculating gamma for k=14:
Number of diseases in cluster: 10.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.8388, -12.8271])
Base value centered (first 5): tensor([-0.2673, -0.2673, -0.2673,  1.7094,  0.7211])
Base value centered mean: -1.5654569551770692e-06
Gamma init for k=14 (first 5): tensor([-0.0072,  0.0036, -0.0037,  0.0201,  0.0041])

Calculating gamma for k=15:
Number of diseases in cluster: 5.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -11.8388])
Base value centered (first 5): tensor([-0.2856, -0.2856, -0.2856, -0.2856,  1.6912])
Base value centered mean: -2.600082325443509e-06
Gamma init for k=15 (first 5): tensor([ 0.0004, -0.0001, -0.0074, -0.0025,  0.0003])

Calculating gamma for k=16:
Number of diseases in cluster: 29.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -11.7706, -13.1339])
Base value centered (first 5): tensor([-0.2510, -0.2510, -0.2510,  1.7939,  0.4307])
Base value centered mean: -2.1326889054762432e-06
Gamma init for k=16 (first 5): tensor([-0.0018, -0.0001,  0.0033,  0.0055,  0.0044])

Calculating gamma for k=17:
Number of diseases in cluster: 17.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -10.3272, -12.0713])
Base value centered (first 5): tensor([-0.4237, -0.4237, -0.4237,  3.0646,  1.3204])
Base value centered mean: -1.1451069212853326e-06
Gamma init for k=17 (first 5): tensor([-0.0018,  0.0008, -0.0002,  0.0074,  0.0048])

Calculating gamma for k=18:
Number of diseases in cluster: 9.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.8155])
Base value centered (first 5): tensor([-0.1585, -0.1585, -0.1585, -0.1585, -0.1585])
Base value centered mean: -8.551031527304076e-08
Gamma init for k=18 (first 5): tensor([-0.0034,  0.0035, -0.0015,  0.0030, -0.0003])

Calculating gamma for k=19:
Number of diseases in cluster: 23.0
Base value (first 5): tensor([-13.8155, -13.8155, -13.8155, -13.8155, -13.3858])
Base value centered (first 5): tensor([-0.1585, -0.1585, -0.1585, -0.1585,  0.2712])
Base value centered mean: -1.9106221316178562e-06
Gamma init for k=19 (first 5): tensor([-0.0008,  0.0012, -0.0013,  0.0010,  0.0019])
Initializing with 20 disease states + 1 healthy state
Initialization complete!

Clusters match exactly: True