
Commit 1fb9777

add beta distribution
1 parent 10d6710 commit 1fb9777

3 files changed: +104 additions, 0 deletions


ngboost/distns/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 """NGBoost distributions"""
+from .beta import Beta
 from .categorical import Bernoulli, k_categorical
 from .cauchy import Cauchy
 from .distn import ClassificationDistn, Distn, RegressionDistn
@@ -15,6 +16,7 @@

 __all__ = [
     "Bernoulli",
+    "Beta",
     "k_categorical",
     "Cauchy",
     "ClassificationDistn",

ngboost/distns/beta.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+"""The NGBoost Beta distribution and scores"""
+import numpy as np
+from scipy.special import digamma, polygamma
+from scipy.stats import beta as dist
+
+from ngboost.distns.distn import RegressionDistn
+from ngboost.scores import LogScore
+
+
+class BetaLogScore(LogScore):
+    """Log score for the Beta distribution."""
+
+    def score(self, Y):
+        """Calculate the log score (negative log-likelihood) for the Beta distribution."""
+        return -self.dist.logpdf(Y)
+
+    def d_score(self, Y):
+        """Calculate the derivative of the log score with respect to the internal parameters."""
+        D = np.zeros((len(Y), 2))  # first col is dS/d(log(a)), second col is dS/d(log(b))
+        D[:, 0] = -self.a * (digamma(self.a + self.b) - digamma(self.a) + np.log(Y))
+        D[:, 1] = -self.b * (digamma(self.a + self.b) - digamma(self.b) + np.log(1 - Y))
+        return D
+
+    def metric(self):
+        """Return the Fisher information matrix for the Beta distribution."""
+        FI = np.zeros((self.a.shape[0], 2, 2))
+        trigamma_a_b = polygamma(1, self.a + self.b)
+        FI[:, 0, 0] = self.a**2 * (polygamma(1, self.a) - trigamma_a_b)
+        FI[:, 0, 1] = -self.a * self.b * trigamma_a_b
+        FI[:, 1, 0] = -self.a * self.b * trigamma_a_b
+        FI[:, 1, 1] = self.b**2 * (polygamma(1, self.b) - trigamma_a_b)
+        return FI
+
+
+class Beta(RegressionDistn):
+    """
+    Implements the Beta distribution for NGBoost.
+
+    The Beta distribution has two shape parameters, a and b, which are fit
+    internally on the log scale. The scipy loc and scale parameters are held
+    constant for this implementation. LogScore is supported for the Beta
+    distribution.
+    """
+
+    n_params = 2
+    scores = [BetaLogScore]
+
+    def __init__(self, params):
+        self._params = params
+
+        # create other objects that will be useful later
+        self.log_a = params[0]
+        self.log_b = params[1]
+        self.a = np.exp(params[0])  # since params[0] is log(a)
+        self.b = np.exp(params[1])  # since params[1] is log(b)
+        self.dist = dist(a=self.a, b=self.b)
+
+    @staticmethod
+    def fit(Y):
+        """Fit the distribution to the data."""
+        # use scipy's beta distribution to fit a and b, holding loc and scale fixed
+        a, b, _loc, _scale = dist.fit(Y, floc=0, fscale=1)
+        return np.array([np.log(a), np.log(b)])
+
+    def sample(self, m):
+        """Draw m samples from the distribution."""
+        return np.array([self.dist.rvs() for _ in range(m)])
+
+    def __getattr__(self, name):
+        # delegate to the frozen scipy distribution; gives us access to
+        # Beta.mean(), which RegressionDistn.predict() requires
+        if name in dir(self.dist):
+            return getattr(self.dist, name)
+        return None
+
+    @property
+    def params(self):
+        """Return the user-facing parameters of the Beta distribution."""
+        return {"a": self.a, "b": self.b}
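For reference, the expressions implemented in d_score and metric follow from the Beta log-density; this derivation is added here for clarity and is not part of the commit. Writing psi for the digamma function and psi_1 for the trigamma function, with internal parameters (log a, log b):

    S(y) = -\log p(y \mid a, b) = \log B(a, b) - (a - 1)\log y - (b - 1)\log(1 - y)

    \frac{\partial S}{\partial \log a} = -a\bigl(\psi(a + b) - \psi(a) + \log y\bigr),
    \qquad
    \frac{\partial S}{\partial \log b} = -b\bigl(\psi(a + b) - \psi(b) + \log(1 - y)\bigr)

    \mathcal{I}_{(\log a,\, \log b)} =
    \begin{pmatrix}
    a^{2}\bigl(\psi_1(a) - \psi_1(a + b)\bigr) & -ab\,\psi_1(a + b) \\
    -ab\,\psi_1(a + b) & b^{2}\bigl(\psi_1(b) - \psi_1(a + b)\bigr)
    \end{pmatrix}

The Fisher information in the log parameters is J^T I(a, b) J with J = diag(a, b), which is exactly the matrix that metric() assembles row by row.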

tests/test_distns.py

Lines changed: 29 additions & 0 deletions
@@ -9,6 +9,7 @@
 from ngboost import NGBClassifier, NGBRegressor, NGBSurvival
 from ngboost.distns import (
     Bernoulli,
+    Beta,
     Cauchy,
     Distn,
     Exponential,
@@ -166,6 +167,34 @@ def test_bernoulli(learner, classification_data: Tuple4Array):
     # test properties of output


+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "learner",
+    [
+        DecisionTreeRegressor(criterion="friedman_mse", max_depth=5),
+        DecisionTreeRegressor(criterion="friedman_mse", max_depth=3),
+    ],
+)
+def test_beta(learner, regression_data: Tuple4Array):
+    X_reg_train, X_reg_test, Y_reg_train, Y_reg_test = regression_data
+
+    # scale the target to the (0, 1) range for the Beta distribution
+    min_value = min(Y_reg_train.min(), Y_reg_test.min())
+    max_value = max(Y_reg_train.max(), Y_reg_test.max())
+    Y_reg_train = (Y_reg_train - min_value) / (max_value - min_value)
+    Y_reg_train = np.clip(Y_reg_train, 1e-5, 1 - 1e-5)  # avoid log(0) issues
+    Y_reg_test = (Y_reg_test - min_value) / (max_value - min_value)
+    Y_reg_test = np.clip(Y_reg_test, 1e-5, 1 - 1e-5)  # avoid log(0) issues
+
+    # test early stopping features
+    # test other args, n_trees, LR, minibatching- args as fixture
+    ngb = NGBRegressor(Dist=Beta, Score=LogScore, Base=learner, verbose=False)
+    ngb.fit(X_reg_train, Y_reg_train)
+    y_pred = ngb.predict(X_reg_test)
+    y_dist = ngb.pred_dist(X_reg_test)
+    # test properties of output
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("k", [2, 4, 7])
 @pytest.mark.parametrize(
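A minimal end-to-end sketch of the new distribution in use, mirroring the test above; the synthetic data, seed, and hyperparameters here are illustrative assumptions, not part of the commit:

    # illustrative sketch: fit NGBoost with the new Beta distribution
    import numpy as np
    from sklearn.model_selection import train_test_split

    from ngboost import NGBRegressor
    from ngboost.distns import Beta
    from ngboost.scores import LogScore

    # synthetic target already in (0, 1); clip to dodge log(0), as the test does
    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 3))
    y = np.clip(rng.beta(a=2.0, b=5.0, size=500), 1e-5, 1 - 1e-5)

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    ngb = NGBRegressor(Dist=Beta, Score=LogScore, verbose=False)
    ngb.fit(X_train, y_train)

    y_pred = ngb.predict(X_test)    # point prediction via the Beta mean
    y_dist = ngb.pred_dist(X_test)  # full predictive distribution
    print(y_dist.params["a"][:3], y_dist.params["b"][:3])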
