@@ -93,6 +93,8 @@ def compile_multi_model_trainer(self,
93
93
the number of neural networks is equal to the number of Gaussian
94
94
processes. For example, if the outputs are spectra of length 128,
95
95
one will have 128 neural networks and 128 GPs trained in parallel.
96
+ It can be also used for training an ensembles of models for the same
97
+ scalar output.
96
98
"""
97
99
if self .correlated_output :
98
100
raise NotImplementedError (
@@ -160,7 +162,8 @@ def compile_trainer(self, X: Union[torch.Tensor, np.ndarray],
160
162
161
163
Keyword Args:
162
164
feature_extractor:
163
- (Optional) Custom neural network for feature extractor
165
+ (Optional) Custom neural network for feature extractor.
166
+ Must take input/feature dims and embedding dims as its arguments.
164
167
grid_size:
165
168
Grid size for structured kernel interpolation (Default: 50)
166
169
freeze_weights:
@@ -174,9 +177,8 @@ def compile_trainer(self, X: Union[torch.Tensor, np.ndarray],
174
177
"use compile_multi_model_trainer(*args, **kwargs)" )
175
178
X , y = self .set_data (X , y )
176
179
input_dim , embedim = self .dimdict ["input_dim" ], self .dimdict ["embedim" ]
177
- feature_extractor = kwargs .get ("feature_extractor" )
178
- if feature_extractor is None :
179
- feature_extractor = fcFeatureExtractor (input_dim , embedim )
180
+ feature_net = kwargs .get ("feature_extractor" , fcFeatureExtractor )
181
+ feature_extractor = feature_net (input_dim , embedim )
180
182
freeze_weights = kwargs .get ("freeze_weights" , False )
181
183
if freeze_weights :
182
184
for p in feature_extractor .parameters ():
0 commit comments