@@ -769,7 +769,7 @@ def sample(self, sample_shape=torch.Size()):
         dm_sample:
             Sample(s) from the Dirichlet Multinomial distribution.
         """
-        shape = self._extended_shape(sample_shape)
+        # shape = self._extended_shape(sample_shape)
         p = td.Dirichlet(self.concentration).sample(sample_shape)
 
         batch_dims = p.shape[:-1]
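For context, the sampler this hunk touches composes a Dirichlet draw with a Multinomial draw. A minimal standalone sketch of that two-stage scheme, assuming a fixed concentration vector and a `total_count` that is illustrative and not taken from the source:

```python
import torch
import torch.distributions as td

# Illustrative values; `concentration` mirrors the diff's self.concentration,
# `total_count` is an assumed parameter not shown in the hunk.
concentration = torch.tensor([2.0, 5.0, 3.0])
total_count = 100

# Stage 1: draw category probabilities from the Dirichlet.
p = td.Dirichlet(concentration).sample()

# Stage 2: draw counts from a Multinomial conditioned on those probabilities.
counts = td.Multinomial(total_count, probs=p).sample()
print(counts.sum())  # tensor(100.)
```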
@@ -1207,7 +1207,7 @@ def kl_div_loss(self, x):
         KL-divergence loss
         """
         q = self.encoder(x)
-        z = q.rsample()
+        # z = q.rsample()
         kl_loss = torch.mean(
             self.beta * td.kl_divergence(q, self.prior()),
             dim=0,
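The `kl_loss` term above relies on the analytic KL from `torch.distributions`, which is why the `q.rsample()` call can be commented out here. A minimal sketch of that closed-form computation, assuming a diagonal-Gaussian posterior and a standard-normal prior; the tensors below are stand-ins, not the repository's encoder outputs:

```python
import torch
import torch.distributions as td

beta = 1.0  # assumed value for self.beta
mu, sigma = torch.zeros(8, 4), torch.ones(8, 4)  # stand-in encoder outputs

q = td.Independent(td.Normal(mu, sigma), 1)                          # q(z|x)
prior = td.Independent(td.Normal(torch.zeros(4), torch.ones(4)), 1)  # p(z)

# Closed-form KL per example, averaged over the batch (dim=0), as in the hunk.
kl_loss = torch.mean(beta * td.kl_divergence(q, prior), dim=0)
```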
@@ -1359,7 +1359,7 @@ def elbo(self, x):
 
     def kl_div_loss(self, x):
         q = self.encoder(x)
-        z = q.rsample()
+        # z = q.rsample()
         kl_loss = torch.mean(
             self.beta * td.kl_divergence(q, self.prior()),
             dim=0,
@@ -2035,7 +2035,7 @@ def train_abaco(
     for loader_data in data_iter:
         x = loader_data[0].to(device)
         y = loader_data[1].to(device).float()  # Batch label
-        z = loader_data[2].to(device).float()  # Bio type label
+        # z = loader_data[2].to(device).float()  # Bio type label
 
         # VAE ELBO computation with masked batch label
         vae_optim_post.zero_grad()
@@ -2050,8 +2050,8 @@ def train_abaco(
         p_xz = vae.decoder(torch.cat([latent_points, alpha * y], dim=1))
 
         # Log probabilities of prior and posterior
-        log_q_zx = q_zx.log_prob(latent_points)
-        log_p_z = vae.log_prob(latent_points)
+        # log_q_zx = q_zx.log_prob(latent_points)
+        # log_p_z = vae.log_prob(latent_points)
 
         # Compute ELBO
         recon_term = p_xz.log_prob(x).mean()
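The commented-out `log_q_zx` and `log_p_z` are the Monte-Carlo pieces of the ELBO that `recon_term` would combine with. A hedged sketch of the single-sample estimate those terms feed, with stand-in distributions in place of the model's `q_zx`, prior, and decoder:

```python
import torch
import torch.distributions as td

x = torch.randn(8, 10)  # stand-in data batch

# Stand-ins for the model's posterior q(z|x), prior p(z), and decoder p(x|z).
q_zx = td.Independent(td.Normal(torch.zeros(8, 4), torch.ones(8, 4)), 1)
p_z = td.Independent(td.Normal(torch.zeros(4), torch.ones(4)), 1)

latent_points = q_zx.rsample()  # reparameterized z ~ q(z|x)
p_xz = td.Independent(td.Normal(torch.zeros(8, 10), torch.ones(8, 10)), 1)

recon_term = p_xz.log_prob(x).mean()     # E[log p(x|z)]
log_q_zx = q_zx.log_prob(latent_points)  # log q(z|x), per example
log_p_z = p_z.log_prob(latent_points)    # log p(z), per example

elbo = recon_term - (log_q_zx - log_p_z).mean()  # single-sample ELBO estimate
```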
@@ -2829,7 +2829,7 @@ def train_abaco_ensemble(
     for loader_data in data_iter:
         x = loader_data[0].to(device)
         y = loader_data[1].to(device).float()  # Batch label
-        z = loader_data[2].to(device).float()  # Bio type label
+        # z = loader_data[2].to(device).float()  # Bio type label
 
         # VAE ELBO computation with masked batch label
         vae_optim_post.zero_grad()
@@ -2849,8 +2849,8 @@ def train_abaco_ensemble(
         p_xzs.append(p_xz)
 
         # Log probabilities of prior and posterior
-        log_q_zx = q_zx.log_prob(latent_points)
-        log_p_z = vae.log_prob(latent_points)
+        # log_q_zx = q_zx.log_prob(latent_points)
+        # log_p_z = vae.log_prob(latent_points)
 
         # Compute ELBO
 
@@ -4529,7 +4529,7 @@ def correct(
         for loader_data in iter(self.dataloader):
             x = loader_data[0].to(self.device)
             ohe_batch = loader_data[1].to(self.device).float()  # Batch label
-            ohe_bio = loader_data[2].to(self.device).float()  # Bio type label
+            # ohe_bio = loader_data[2].to(self.device).float()  # Bio type label
 
             # Encode and decode the input data along with the one-hot encoded batch label
             q_zx = self.vae.encoder(torch.cat([x, ohe_batch], dim=1))  # td.Distribution