Commit 5d484ee

lint: linting
1 parent 1fb9777 commit 5d484ee

2 files changed: +23 additions, -14 deletions

ngboost/distns/beta.py

Lines changed: 21 additions & 12 deletions
@@ -1,34 +1,39 @@
 """The NGBoost Beta distribution and scores"""
-from scipy.stats import beta as dist
-from scipy.special import digamma, polygamma
 import numpy as np
+from scipy.special import digamma, polygamma
+from scipy.stats import beta as dist

 from ngboost.distns.distn import RegressionDistn
 from ngboost.scores import LogScore

+
 class BetaLogScore(LogScore):
     """Log score for the Beta distribution."""
+
     def score(self, Y):
         """Calculate the log score for the Beta distribution."""
         return -self.dist.logpdf(Y)

     def d_score(self, Y):
         """Calculate the derivative of the log score with respect to the parameters."""
-        D = np.zeros((len(Y), 2)) # first col is dS/d(log(a)), second col is dS/d(log(b))
-        D[:, 0] = - self.a * ( digamma(self.a + self.b) - digamma(self.a) + np.log(Y))
-        D[:, 1] = - self.b * ( digamma(self.a + self.b) - digamma(self.b) + np.log(1 - Y))
+        D = np.zeros(
+            (len(Y), 2)
+        )  # first col is dS/d(log(a)), second col is dS/d(log(b))
+        D[:, 0] = -self.a * (digamma(self.a + self.b) - digamma(self.a) + np.log(Y))
+        D[:, 1] = -self.b * (digamma(self.a + self.b) - digamma(self.b) + np.log(1 - Y))
         return D

     def metric(self):
         """Return the Fisher Information matrix for the Beta distribution."""
         FI = np.zeros((self.a.shape[0], 2, 2))
         trigamma_a_b = polygamma(1, self.a + self.b)
-        FI[:, 0, 0] = self.a**2 * ( polygamma(1, self.a) - trigamma_a_b )
+        FI[:, 0, 0] = self.a**2 * (polygamma(1, self.a) - trigamma_a_b)
         FI[:, 0, 1] = -self.a * self.b * trigamma_a_b
         FI[:, 1, 0] = -self.a * self.b * trigamma_a_b
-        FI[:, 1, 1] = self.b**2 * ( polygamma(1, self.b) - trigamma_a_b )
+        FI[:, 1, 1] = self.b**2 * (polygamma(1, self.b) - trigamma_a_b)
         return FI

+
 class Beta(RegressionDistn):
     """
     Implements the Beta distribution for NGBoost.
@@ -39,35 +44,39 @@ class Beta(RegressionDistn):
     """

     n_params = 2
-    scores = [BetaLogScore] # will implement this later
+    scores = [BetaLogScore]  # will implement this later

+    # pylint: disable=super-init-not-called
     def __init__(self, params):
         self._params = params

         # create other objects that will be useful later
         self.log_a = params[0]
         self.log_b = params[1]
-        self.a = np.exp(params[0]) # since params[0] is log(a)
-        self.b = np.exp(params[1]) # since params[1] is log(b)
+        self.a = np.exp(params[0])  # since params[0] is log(a)
+        self.b = np.exp(params[1])  # since params[1] is log(b)
         self.dist = dist(a=self.a, b=self.b)

     @staticmethod
     def fit(Y):
         """Fit the distribution to the data."""
         # Use scipy's beta distribution to fit the parameters
+        # pylint: disable=unused-variable
         a, b, loc, scale = dist.fit(Y, floc=0, fscale=1)
         return np.array([np.log(a), np.log(b)])

     def sample(self, m):
         """Sample from the distribution."""
         return np.array([self.dist.rvs() for i in range(m)])

-    def __getattr__(self, name): # gives us access to Beta.mean() required for RegressionDist.predict()
+    def __getattr__(
+        self, name
+    ):  # gives us access to Beta.mean() required for RegressionDist.predict()
         if name in dir(self.dist):
             return getattr(self.dist, name)
         return None

     @property
     def params(self):
         """Return the parameters of the Beta distribution."""
-        return {'a': self.a, 'b': self.b}
+        return {"a": self.a, "b": self.b}
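
For reference, the expressions being reformatted in `d_score` and `metric` follow directly from the Beta log-density. With score S(y) = -log f(y) and the parameters modeled on the log scale, the chain rule gives ∂S/∂(log a) = a · ∂S/∂a. In LaTeX:

\log f(y; a, b) = (a - 1)\log y + (b - 1)\log(1 - y) - \log B(a, b)

\frac{\partial S}{\partial \log a} = -a\bigl(\psi(a + b) - \psi(a) + \log y\bigr),
\qquad
\frac{\partial S}{\partial \log b} = -b\bigl(\psi(a + b) - \psi(b) + \log(1 - y)\bigr)

and the Fisher information in (log a, log b) coordinates is

\mathcal{I} =
\begin{pmatrix}
a^{2}\bigl(\psi'(a) - \psi'(a + b)\bigr) & -ab\,\psi'(a + b) \\
-ab\,\psi'(a + b) & b^{2}\bigl(\psi'(b) - \psi'(a + b)\bigr)
\end{pmatrix}

where \psi is the digamma function and \psi' is the trigamma function (`polygamma(1, ·)`), matching `D[:, 0]`, `D[:, 1]`, and the four `FI` entries above.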

tests/test_distns.py

Lines changed: 2 additions & 2 deletions
@@ -181,9 +181,9 @@ def test_beta(learner, regression_data: Tuple4Array):
     # Scale the target to (0, 1) range for Beta distribution
     min_value = min(Y_reg_train.min(), Y_reg_test.min())
     max_value = max(Y_reg_train.max(), Y_reg_test.max())
-    Y_reg_train = (Y_reg_train - min_value)/(max_value - min_value)
+    Y_reg_train = (Y_reg_train - min_value) / (max_value - min_value)
     Y_reg_train = np.clip(Y_reg_train, 1e-5, 1 - 1e-5)  # Avoid log(0) issues
-    Y_reg_test = (Y_reg_test - min_value)/(max_value - min_value)
+    Y_reg_test = (Y_reg_test - min_value) / (max_value - min_value)
     Y_reg_test = np.clip(Y_reg_test, 1e-5, 1 - 1e-5)  # Avoid log(0) issues

     # test early stopping features
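
A minimal usage sketch of the distribution this test exercises, assuming the public `NGBRegressor` API; the synthetic data and the direct `ngboost.distns.beta` import path are illustrative choices, not part of the commit. As in the test, targets must lie strictly in (0, 1), hence the rescale-and-clip step:

import numpy as np

from ngboost import NGBRegressor
from ngboost.distns.beta import Beta

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
# Hypothetical target: squash a noisy linear signal into (0, 1) via a sigmoid
Y = 1.0 / (1.0 + np.exp(-(X[:, 0] + 0.5 * X[:, 1] + 0.1 * rng.normal(size=500))))
Y = np.clip(Y, 1e-5, 1 - 1e-5)  # keep Y off the boundary, as the test does

ngb = NGBRegressor(Dist=Beta, verbose=False).fit(X, Y)
dist_pred = ngb.pred_dist(X[:5])
print(dist_pred.params)  # per-observation {"a": ..., "b": ...}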
