
Commit d6c33e2

Custom NNs must take feature & embed dims as args

1 parent: add0a41
File tree: 2 files changed, +11 −7 lines

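What the commit means in practice: the trainer now instantiates the feature extractor itself, so a custom network must be passed as a class (or factory) whose constructor takes the input/feature dimensionality and the embedding dimensionality. Below is a minimal sketch of a conforming extractor, assuming atomai's dklGPR interface; the CustomFeatureExtractor name, hidden width, and toy data shapes are illustrative, not part of the commit.

    import numpy as np
    import torch
    from atomai.models import dklGPR

    class CustomFeatureExtractor(torch.nn.Module):
        # Per this commit, the constructor must accept the input/feature dim
        # and the embedding dim, since the trainer calls cls(input_dim, embedim).
        def __init__(self, input_dim: int, embedim: int) -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(input_dim, 64),   # hidden width is arbitrary here
                torch.nn.ReLU(),
                torch.nn.Linear(64, embedim),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

    # Toy data and call (shapes and training_cycles value assumed); pass the
    # class itself, not an instance (see the gptrainer.py change below).
    X, y = np.random.randn(200, 16), np.random.randn(200)
    model = dklGPR(16, embedim=2)  # 16 = input dim of the toy data
    model.fit(X, y, feature_extractor=CustomFeatureExtractor, training_cycles=10)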

atomai/models/dklgp/dklgpr.py

Lines changed: 5 additions & 3 deletions
@@ -82,7 +82,8 @@ def fit(self, X: Union[torch.Tensor, np.ndarray],
 
         Keyword Args:
             feature_extractor:
-                (Optional) Custom neural network for feature extractor
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
             freeze_weights:
                 Freezes weights of feature extractor, that is, they are not
                 passed to the optimizer. Used for a transfer learning.
@@ -98,7 +99,7 @@ def fit_ensemble(self, X: Union[torch.Tensor, np.ndarray],
                  **kwargs: Union[Type[torch.nn.Module], bool, float]
                  ) -> None:
         """
-        Initializes and trains a deep kernel GP model
+        Initializes and trains an ensemble of deep kernel GP model
 
         Args:
             X: Input training data (aka features) of N x input_dim dimensions
@@ -108,7 +109,8 @@ def fit_ensemble(self, X: Union[torch.Tensor, np.ndarray],
 
         Keyword Args:
             feature_extractor:
-                (Optional) Custom neural network for feature extractor
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
             freeze_weights:
                 Freezes weights of feature extractor, that is, they are not
                 passed to the optimizer. Used for a transfer learning.
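The freeze_weights docstring above is where the new constructor contract matters for transfer learning: since the trainer builds the extractor itself via feature_net(input_dim, embedim), pretrained weights have to be restored inside the constructor or through a small factory. A sketch under those assumptions, reusing the illustrative CustomFeatureExtractor from above (make_pretrained and weights.pt are hypothetical):

    def make_pretrained(input_dim: int, embedim: int) -> torch.nn.Module:
        # Factory obeying the (input_dim, embedim) contract; the checkpoint
        # path is a placeholder, not something from this commit.
        net = CustomFeatureExtractor(input_dim, embedim)
        net.load_state_dict(torch.load("weights.pt"))
        return net

    # freeze_weights=True keeps the extractor's parameters out of the
    # optimizer, as the docstring above describes.
    model.fit(X, y, feature_extractor=make_pretrained, freeze_weights=True)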

atomai/trainers/gptrainer.py

Lines changed: 6 additions & 4 deletions
@@ -93,6 +93,8 @@ def compile_multi_model_trainer(self,
         the number of neural networks is equal to the number of Gaussian
         processes. For example, if the outputs are spectra of length 128,
         one will have 128 neural networks and 128 GPs trained in parallel.
+        It can be also used for training an ensembles of models for the same
+        scalar output.
         """
         if self.correlated_output:
             raise NotImplementedError(
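The sentence added above connects the multi-model trainer to fit_ensemble in dklgpr.py: the same machinery that trains one NN+GP pair per output dimension can also train several models on one scalar output. A hedged usage sketch; n_models is an assumed keyword and may differ from the actual fit_ensemble signature:

    # Ensemble of deep kernel GP models on the same scalar target (sketch).
    model = dklGPR(16, embedim=2)
    model.fit_ensemble(X, y, n_models=5, training_cycles=10,
                       feature_extractor=CustomFeatureExtractor)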
@@ -160,7 +162,8 @@ def compile_trainer(self, X: Union[torch.Tensor, np.ndarray],
 
         Keyword Args:
             feature_extractor:
-                (Optional) Custom neural network for feature extractor
+                (Optional) Custom neural network for feature extractor.
+                Must take input/feature dims and embedding dims as its arguments.
             grid_size:
                 Grid size for structured kernel interpolation (Default: 50)
             freeze_weights:
@@ -174,9 +177,8 @@ def compile_trainer(self, X: Union[torch.Tensor, np.ndarray],
                 "use compile_multi_model_trainer(*args, **kwargs)")
         X, y = self.set_data(X, y)
         input_dim, embedim = self.dimdict["input_dim"], self.dimdict["embedim"]
-        feature_extractor = kwargs.get("feature_extractor")
-        if feature_extractor is None:
-            feature_extractor = fcFeatureExtractor(input_dim, embedim)
+        feature_net = kwargs.get("feature_extractor", fcFeatureExtractor)
+        feature_extractor = feature_net(input_dim, embedim)
         freeze_weights = kwargs.get("freeze_weights", False)
         if freeze_weights:
             for p in feature_extractor.parameters():
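This refactor is the source of the requirement in the commit title: instead of special-casing None, the trainer fetches the extractor class (defaulting to fcFeatureExtractor) and unconditionally calls it with (input_dim, embedim). The caller-side consequence, sketched with the illustrative class from above:

    # Pass the class: the trainer itself calls feature_net(input_dim, embedim).
    model.fit(X, y, feature_extractor=CustomFeatureExtractor)
    # Passing an instance would now misfire: the already-built module would be
    # *called* with (input_dim, embedim), i.e. treated as a forward pass.
    # model.fit(X, y, feature_extractor=CustomFeatureExtractor(16, 2))  # wrong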
