Commit 2a4f329 (2 parents: 5fc59fe + 11261a6)
src/abaco/ABaCo.py
@@ -4418,11 +4418,22 @@ def fit(
         adv_lr=1e-3,
     ):
         # Define optimizer
+        if isinstance(self.vae.prior, MoCPPrior):
+            prior_params = self.vae.prior.parameters()
+
+        elif isinstance(self.vae.prior, VMMPrior):
+            prior_params = [self.vae.prior.u, self.vae.prior.var]
+
+        else:
+            raise NotImplementedError(
+                "metaABaCo prior distribution can only be 'MoG' or 'VMM'"
+            )
+
         vae_optimizer_1 = torch.optim.Adam(
             [
                 {"params": self.vae.encoder.parameters()},
                 {"params": self.vae.decoder.parameters()},
-                {"params": self.vae.prior.parameters()},
+                {"params": prior_params},
             ],
             lr=phase_1_vae_lr,
         )
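The change routes the prior's trainable tensors into the phase-1 Adam optimizer explicitly: a module-style prior (MoCPPrior) can expose them via .parameters(), while a VMMPrior holds raw tensors (u, var) that must be passed as a list. As a rough, hedged illustration of why this works, the sketch below is not taken from the ABaCo codebase (ToyVMMPrior and all shapes are made up); it only shows that torch.optim.Adam accepts a param group whose "params" entry is a plain list of requires_grad tensors alongside groups built from .parameters().

# Minimal sketch, assuming only standard PyTorch; stand-ins for self.vae.encoder,
# self.vae.decoder, and the VMM-style prior from the diff.
import torch
import torch.nn as nn

encoder = nn.Linear(16, 8)   # stand-in for self.vae.encoder
decoder = nn.Linear(8, 16)   # stand-in for self.vae.decoder

class ToyVMMPrior:
    # Hypothetical prior that stores raw tensors instead of nn.Parameters.
    def __init__(self, n_components, dim):
        self.u = torch.randn(n_components, dim, requires_grad=True)
        self.var = torch.ones(n_components, dim, requires_grad=True)

prior = ToyVMMPrior(n_components=4, dim=8)
prior_params = [prior.u, prior.var]  # mirrors the VMMPrior branch in the diff

optimizer = torch.optim.Adam(
    [
        {"params": encoder.parameters()},
        {"params": decoder.parameters()},
        {"params": prior_params},
    ],
    lr=1e-3,
)

# One dummy step to show all three param groups receive gradients and are updated.
x = torch.randn(2, 16)
loss = decoder(encoder(x)).pow(2).mean() + prior.u.pow(2).mean() + prior.var.pow(2).mean()
loss.backward()
optimizer.step()

Passing the raw tensors directly works because optimizer param groups only require torch.Tensor objects with requires_grad=True, not nn.Parameter instances registered on a module.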