Skip to content

Commit 2a4f329

Browse files
authored
Merge pull request #22 from Multiomics-Analytics-Group/ecoli-19-metaabacofit-with-vmm
fixed metaABaCo.fit() with VMM prior
2 parents 5fc59fe + 11261a6 commit 2a4f329

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

src/abaco/ABaCo.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4418,11 +4418,22 @@ def fit(
44184418
adv_lr=1e-3,
44194419
):
44204420
# Define optimizer
4421+
if isinstance(self.vae.prior, MoCPPrior):
4422+
prior_params = self.vae.prior.parameters()
4423+
4424+
elif isinstance(self.vae.prior, VMMPrior):
4425+
prior_params = [self.vae.prior.u, self.vae.prior.var]
4426+
4427+
else:
4428+
raise NotImplementedError(
4429+
"metaABaCo prior distribution can only be 'MoG' or 'VMM'"
4430+
)
4431+
44214432
vae_optimizer_1 = torch.optim.Adam(
44224433
[
44234434
{"params": self.vae.encoder.parameters()},
44244435
{"params": self.vae.decoder.parameters()},
4425-
{"params": self.vae.prior.parameters()},
4436+
{"params": prior_params},
44264437
],
44274438
lr=phase_1_vae_lr,
44284439
)

0 commit comments

Comments
 (0)