Skip to content

Commit 5aa59fb

Browse files
committed
lint
1 parent 8f89738 commit 5aa59fb

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

pyro/distributions/sine_bivariate_von_mises.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ class SineBivariateVonMises(TorchDistribution):
4545
4646
\frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1
4747
48-
because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the
49-
`weighted_correlation` parameter with a skew away from one (e.g.,
50-
`TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation`
48+
because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the
49+
`weighted_correlation` parameter with a skew away from one (e.g.,
50+
`TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation`
5151
should be in [-1,1].
5252
5353
.. note:: The correlation and weighted_correlation params are mutually exclusive.
@@ -141,7 +141,8 @@ def norm_const(self):
141141
- m * torch.log(4 * torch.prod(conc, dim=-1))
142142
)
143143
num_I1terms = torch.maximum(
144-
torch.tensor(501), torch.max(self.phi_concentration) + torch.max(self.psi_concentration)
144+
torch.tensor(501),
145+
torch.max(self.phi_concentration) + torch.max(self.psi_concentration),
145146
).int()
146147

147148
fs += log_I1(m.max(), conc, num_I1terms).sum(-1)

tests/distributions/test_sine_bivariate_von_mises.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,15 @@ def guide(data):
131131

132132
assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2)
133133

134+
134135
@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0])
135136
def test_sine_bivariate_von_mises_norm(conc):
136137
dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
137138
num_samples = 500
138139
x = torch.linspace(-torch.pi, torch.pi, num_samples)
139140
y = torch.linspace(-torch.pi, torch.pi, num_samples)
140-
mesh = torch.stack(torch.meshgrid(x, y, indexing='ij'), axis=-1)
141+
mesh = torch.stack(torch.meshgrid(x, y, indexing="ij"), axis=-1)
141142
integral_torus = (
142143
torch.exp(dist.log_prob(mesh)) * (2 * torch.pi) ** 2 / num_samples**2
143144
).sum()
144-
assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)
145+
assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)

0 commit comments

Comments
 (0)