@@ -149,12 +149,14 @@ def __init__(
149
149
):
150
150
super ().__init__ (posterior , inducing_inputs , jitter )
151
151
152
- self .variational_mean = Real (
153
- variational_mean or jnp .zeros ((self .num_inducing , 1 ))
154
- )
155
- self .variational_root_covariance = LowerTriangular (
156
- variational_root_covariance or jnp .eye (self .num_inducing )
157
- )
152
+ if variational_mean is None :
153
+ variational_mean = jnp .zeros ((self .num_inducing , 1 ))
154
+
155
+ if variational_root_covariance is None :
156
+ variational_root_covariance = jnp .eye (self .num_inducing )
157
+
158
+ self .variational_mean = Real (variational_mean )
159
+ self .variational_root_covariance = LowerTriangular (variational_root_covariance )
158
160
159
161
def prior_kl (self ) -> ScalarFloat :
160
162
r"""Compute the prior KL divergence.
@@ -378,12 +380,14 @@ def __init__(
378
380
):
379
381
super ().__init__ (posterior , inducing_inputs , jitter )
380
382
381
- self .natural_vector = Static (
382
- natural_vector or jnp .zeros ((self .num_inducing , 1 ))
383
- )
384
- self .natural_matrix = Static (
385
- natural_matrix or - 0.5 * jnp .eye (self .num_inducing )
386
- )
383
+ if natural_vector is None :
384
+ natural_vector = jnp .zeros ((self .num_inducing , 1 ))
385
+
386
+ if natural_matrix is None :
387
+ natural_matrix = - 0.5 * jnp .eye (self .num_inducing )
388
+
389
+ self .natural_vector = Static (natural_vector )
390
+ self .natural_matrix = Static (natural_matrix )
387
391
388
392
def prior_kl (self ) -> ScalarFloat :
389
393
r"""Compute the KL-divergence between our current variational approximation
@@ -540,13 +544,14 @@ def __init__(
540
544
):
541
545
super ().__init__ (posterior , inducing_inputs , jitter )
542
546
543
- # must come after super().__init__
544
- self .expectation_vector = Static (
545
- expectation_vector or jnp .zeros ((self .num_inducing , 1 ))
546
- )
547
- self .expectation_matrix = Static (
548
- expectation_matrix or jnp .eye (self .num_inducing )
549
- )
547
+ if expectation_vector is None :
548
+ expectation_vector = jnp .zeros ((self .num_inducing , 1 ))
549
+
550
+ if expectation_matrix is None :
551
+ expectation_matrix = jnp .eye (self .num_inducing )
552
+
553
+ self .expectation_vector = Static (expectation_vector )
554
+ self .expectation_matrix = Static (expectation_matrix )
550
555
551
556
def prior_kl (self ) -> ScalarFloat :
552
557
r"""Evaluate the prior KL-divergence.
0 commit comments