Skip to content

Commit be2540f

Browse files
Fix optional array arguments in class constructors (#488)
1 parent d92f222 commit be2540f

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

gpjax/gps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,8 @@ def __init__(
         """
         super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)

-        latent = latent or jr.normal(key, shape=(self.likelihood.num_datapoints, 1))
+        if latent is None:
+            latent = jr.normal(key, shape=(self.likelihood.num_datapoints, 1))

         # TODO: static or intermediate?
         self.latent = latent if isinstance(latent, Parameter) else Real(latent)

gpjax/variational_families.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,14 @@ def __init__(
     ):
         super().__init__(posterior, inducing_inputs, jitter)

-        self.variational_mean = Real(
-            variational_mean or jnp.zeros((self.num_inducing, 1))
-        )
-        self.variational_root_covariance = LowerTriangular(
-            variational_root_covariance or jnp.eye(self.num_inducing)
-        )
+        if variational_mean is None:
+            variational_mean = jnp.zeros((self.num_inducing, 1))
+
+        if variational_root_covariance is None:
+            variational_root_covariance = jnp.eye(self.num_inducing)
+
+        self.variational_mean = Real(variational_mean)
+        self.variational_root_covariance = LowerTriangular(variational_root_covariance)

     def prior_kl(self) -> ScalarFloat:
         r"""Compute the prior KL divergence.
@@ -378,12 +380,14 @@ def __init__(
     ):
         super().__init__(posterior, inducing_inputs, jitter)

-        self.natural_vector = Static(
-            natural_vector or jnp.zeros((self.num_inducing, 1))
-        )
-        self.natural_matrix = Static(
-            natural_matrix or -0.5 * jnp.eye(self.num_inducing)
-        )
+        if natural_vector is None:
+            natural_vector = jnp.zeros((self.num_inducing, 1))
+
+        if natural_matrix is None:
+            natural_matrix = -0.5 * jnp.eye(self.num_inducing)
+
+        self.natural_vector = Static(natural_vector)
+        self.natural_matrix = Static(natural_matrix)

     def prior_kl(self) -> ScalarFloat:
         r"""Compute the KL-divergence between our current variational approximation
@@ -540,13 +544,14 @@ def __init__(
     ):
         super().__init__(posterior, inducing_inputs, jitter)

-        # must come after super().__init__
-        self.expectation_vector = Static(
-            expectation_vector or jnp.zeros((self.num_inducing, 1))
-        )
-        self.expectation_matrix = Static(
-            expectation_matrix or jnp.eye(self.num_inducing)
-        )
+        if expectation_vector is None:
+            expectation_vector = jnp.zeros((self.num_inducing, 1))
+
+        if expectation_matrix is None:
+            expectation_matrix = jnp.eye(self.num_inducing)
+
+        self.expectation_vector = Static(expectation_vector)
+        self.expectation_matrix = Static(expectation_matrix)

     def prior_kl(self) -> ScalarFloat:
         r"""Evaluate the prior KL-divergence.

0 commit comments

Comments (0)