
Commit 4586efa

Merge pull request #56 from ziatdinovmax/master
Add ensemble DKL and notebook with example
2 parents 676b02c + d6c33e2

File tree

5 files changed: +2081, -12 lines

atomai/models/dklgp/dklgpr.py

Lines changed: 42 additions & 2 deletions
@@ -82,7 +82,8 @@ def fit(self, X: Union[torch.Tensor, np.ndarray],
 
         Keyword Args:
             feature_extractor:
-                (Optional) Custom neural network for feature extractor
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
             freeze_weights:
                 Freezes weights of feature extractor, that is, they are not
                 passed to the optimizer. Used for a transfer learning.
@@ -91,6 +92,45 @@ def fit(self, X: Union[torch.Tensor, np.ndarray],
         """
         _ = self.run(X, y, training_cycles, **kwargs)
 
+    def fit_ensemble(self, X: Union[torch.Tensor, np.ndarray],
+                     y: Union[torch.Tensor, np.ndarray],
+                     training_cycles: int = 1,
+                     n_models: int = 5,
+                     **kwargs: Union[Type[torch.nn.Module], bool, float]
+                     ) -> None:
+        """
+        Initializes and trains an ensemble of deep kernel GP models
+
+        Args:
+            X: Input training data (aka features) of N x input_dim dimensions
+            y: Output targets of batch_size x N or N (if batch_size=1) dimensions
+            training_cycles: Number of training epochs
+            n_models: Number of models in ensemble
+
+        Keyword Args:
+            feature_extractor:
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
+            freeze_weights:
+                Freezes weights of feature extractor, that is, they are not
+                passed to the optimizer. Used for transfer learning.
+            lr: learning rate (Default: 0.01)
+            print_loss: print loss at every n-th training cycle (epoch)
+        """
+        if y.ndim == 1:
+            y = y[None]
+        if y.shape[0] > 1:
+            raise NotImplementedError(
+                "The ensemble training is currently supported only for scalar targets")
+        y = y.repeat(n_models, 0) if isinstance(y, np.ndarray) else y.repeat(n_models, 1)
+        if self.correlated_output:
+            msg = ("Replacing shared independent embedding space with" +
+                   " {} independent ones").format(n_models)
+            warnings.warn(msg)
+            self.correlated_output = False
+        self.ensemble = True
+        _ = self.run(X, y, training_cycles, **kwargs)
+
     def _compute_posterior(self, X: torch.Tensor) -> Union[mvn_, List[mvn_]]:
         """
         Computes the posterior over model outputs at the provided points (X).
@@ -194,7 +234,7 @@ def embed(self, x_new: Union[torch.Tensor, np.ndarray],
         x_new, _ = self.set_data(x_new, device='cpu')
         data_loader = init_dataloader(x_new, shuffle=False, **kwargs)
         embeded = torch.cat([self._embed(x.to(self.device)) for (x,) in data_loader], 0)
-        if not self.correlated_output:
+        if not self.correlated_output and not self.ensemble:
             embeded = embeded.permute(-1, 0, 1)
         return embeded.numpy()
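For orientation, here is a minimal usage sketch of the new `fit_ensemble` method. The synthetic data and the `dklGPR` constructor arguments shown are illustrative assumptions, not part of this diff:

```python
# Minimal sketch of the fit_ensemble API added above.
# Assumptions (not part of this diff): the dklGPR class exported from
# atomai.models and its (indim, embedim) constructor; the data is synthetic.
import numpy as np
from atomai.models import dklGPR

X = np.random.randn(200, 64)   # N x input_dim features
y = np.random.randn(200)       # scalar targets, shape (N,)

model = dklGPR(64, embedim=2)  # assumed constructor signature
# Trains 5 independently initialized DKL-GP models on the same data
model.fit_ensemble(X, y, training_cycles=200, n_models=5, lr=0.01)
```

As the implementation shows, a 1-D `y` is promoted to `(1, N)` and tiled `n_models` times, so every ensemble member sees the same scalar target but starts from its own randomly initialized feature extractor.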

atomai/trainers/gptrainer.py

Lines changed: 18 additions & 10 deletions
@@ -52,6 +52,7 @@ def __init__(self,
         self.correlated_output = shared_embedding_space
         self.gp_model = None
         self.likelihood = None
+        self.ensemble = False
         self.compiled = False
         self.train_loss = []
 
@@ -92,6 +93,8 @@ def compile_multi_model_trainer(self,
         the number of neural networks is equal to the number of Gaussian
         processes. For example, if the outputs are spectra of length 128,
         one will have 128 neural networks and 128 GPs trained in parallel.
+        It can also be used for training an ensemble of models for the same
+        scalar output.
         """
         if self.correlated_output:
             raise NotImplementedError(
@@ -101,16 +104,21 @@
         if y.shape[0] < 2:
             raise ValueError("The training targets must be vector-valued (d >1)")
         input_dim, embedim = self.dimdict["input_dim"], self.dimdict["embedim"]
-        feature_extractor = kwargs.get("feature_extractor")
-        if feature_extractor is None:
-            feature_extractor = fcFeatureExtractor(input_dim, embedim)
+        feature_net = kwargs.get("feature_extractor", fcFeatureExtractor)
         freeze_weights = kwargs.get("freeze_weights", False)
-        if freeze_weights:
-            for p in feature_extractor.parameters():
-                p.requires_grad = False
+        if not self.ensemble:
+            feature_extractor = feature_net(input_dim, embedim)
+            if freeze_weights:
+                for p in feature_extractor.parameters():
+                    p.requires_grad = False
         list_of_models = []
         list_of_likelihoods = []
         for i in range(y.shape[0]):
+            if self.ensemble:  # different initialization for each model
+                feature_extractor = feature_net(input_dim, embedim)
+                if freeze_weights:
+                    for p in feature_extractor.parameters():
+                        p.requires_grad = False
             model_i = GPRegressionModel(
                 X, y[i:i+1],
                 gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([1])),
@@ -154,7 +162,8 @@ def compile_trainer(self, X: Union[torch.Tensor, np.ndarray],
 
         Keyword Args:
             feature_extractor:
-                (Optional) Custom neural network for feature extractor
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
             grid_size:
                 Grid size for structured kernel interpolation (Default: 50)
             freeze_weights:
@@ -168,9 +177,8 @@
                 "use compile_multi_model_trainer(*args, **kwargs)")
         X, y = self.set_data(X, y)
         input_dim, embedim = self.dimdict["input_dim"], self.dimdict["embedim"]
-        feature_extractor = kwargs.get("feature_extractor")
-        if feature_extractor is None:
-            feature_extractor = fcFeatureExtractor(input_dim, embedim)
+        feature_net = kwargs.get("feature_extractor", fcFeatureExtractor)
+        feature_extractor = feature_net(input_dim, embedim)
         freeze_weights = kwargs.get("freeze_weights", False)
         if freeze_weights:
             for p in feature_extractor.parameters():
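Note the behavioral change in both methods: `feature_extractor` is now fetched as an uninstantiated class (or factory) and called as `feature_net(input_dim, embedim)`, rather than being passed in as a ready-made module. Below is a sketch of a conforming custom extractor; the class name and layer sizes are hypothetical:

```python
import torch.nn as nn

# Hypothetical custom feature extractor. The only contract implied by
# the diff is that the constructor takes input/feature dims and
# embedding dims, matching the feature_net(input_dim, embedim) call.
class CustomFeatureExtractor(nn.Sequential):
    def __init__(self, input_dim: int, embedim: int):
        super().__init__(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, embedim),
        )

# Passed uninstantiated, e.g.:
# model.fit_ensemble(X, y, n_models=5, feature_extractor=CustomFeatureExtractor)
```

Because `compile_multi_model_trainer` re-instantiates `feature_net(input_dim, embedim)` inside the loop when `self.ensemble` is set, each ensemble member receives its own randomly initialized copy of the network, which is what gives the ensemble its diversity.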
