Skip to content

Commit b75afea

Browse files
authored
Merge pull request #92 from nlesc-nano/dev
Minor clean
2 parents 8615ec8 + 192d5d4 commit b75afea

File tree

10 files changed

+61
-35
lines changed

10 files changed

+61
-35
lines changed

scripts/predict_gp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
researcher.load_model("swan_chk.pt")
3131

3232
fingers = data.fingerprints
33-
print("shape fingers: ", fingers.shape)
3433
predicted = researcher.predict(fingers)
3534
df = pd.DataFrame(
3635
{"smiles": data.dataframe.smiles, "mean": predicted.mean, "lower": predicted.lower, "upper": predicted.upper})

swan/modeller/base_modeller.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,25 @@ def load_model(self, path_model: Optional[PathLike]) -> None:
6767
def save_model(self, *args, **kwargs):
6868
"""Store the trained model."""
6969
raise NotImplementedError
70+
71+
def store_trainset_in_state(self, indices: T_co, ntrain: int, store_features: bool = True) -> None:
72+
"""Store features, indices, smiles, etc. into the state file."""
73+
self.state.store_array("indices", indices, "int")
74+
self.state.store_array("ntrain", ntrain, "int")
75+
self.state.store_array("smiles_train", self.smiles[indices[:ntrain]], dtype="str")
76+
self.state.store_array("smiles_validate", self.smiles[indices[ntrain:]], dtype="str")
77+
78+
if isinstance(self.labels_trainset, torch.Tensor):
79+
self.state.store_array("labels_trainset", self.labels_trainset.numpy())
80+
self.state.store_array("labels_validset", self.labels_validset.numpy())
81+
else:
82+
self.state.store_array("labels_trainset", self.labels_trainset)
83+
self.state.store_array("labels_validset", self.labels_validset)
84+
85+
if store_features:
86+
if isinstance(self.features_trainset, torch.Tensor):
87+
self.state.store_array("features_trainset", self.features_trainset.numpy())
88+
self.state.store_array("features_validset", self.features_validset.numpy())
89+
else:
90+
self.state.store_array("features_trainset", self.features_trainset)
91+
self.state.store_array("features_validset", self.features_validset)

swan/modeller/gp_modeller.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,7 @@ def split_data(self, partition: SplitDataset) -> None:
6868
self.labels_validset = partition.labels_validset
6969
warnings.warn("The labels have not been scaled. Is this the intended behavior?", UserWarning)
7070

71-
indices = partition.indices
72-
ntrain = partition.ntrain
73-
self.state.store_array("smiles_train", self.smiles[indices[:ntrain]], dtype="str")
74-
self.state.store_array("smiles_validate", self.smiles[indices[ntrain:]], dtype="str")
75-
self.state.store_array("features_trainset", self.features_trainset.numpy())
76-
self.state.store_array("features_validset", self.features_validset.numpy())
77-
self.state.store_array("labels_trainset", self.labels_trainset.numpy())
78-
self.state.store_array("labels_validset", self.labels_validset.numpy())
79-
self.state.store_array("indices", indices, "int")
80-
self.state.store_array("ntrain", ntrain, "int")
71+
self.store_trainset_in_state(partition.indices, partition.ntrain)
8172

8273
def train_model(self,
8374
nepoch: int,
@@ -174,8 +165,6 @@ def predict(self, inp_data: Tensor) -> GPMultivariate:
174165
self.network.likelihood.eval()
175166

176167
with torch.no_grad(), gp.settings.fast_pred_var():
177-
first = self.network(inp_data)
178-
print(first.mean)
179168
output = self.network.likelihood(self.network(inp_data))
180169
return self._create_result_object(output)
181170

swan/modeller/scikit_modeller.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,7 @@ def split_data(self, frac: Tuple[float, float]) -> None:
6969
self.labels_validset = partition.labels_validset
7070

7171
# Split the smiles using the same partition as the features
72-
indices = partition.indices
73-
ntrain = partition.ntrain
74-
self.state.store_array("smiles_train", self.smiles[indices[:ntrain]], dtype="str")
75-
self.state.store_array("smiles_validate", self.smiles[indices[ntrain:]], dtype="str")
72+
self.store_trainset_in_state(partition.indices, partition.ntrain)
7673

7774
def save_model(self):
7875
"""Store the trained model."""
@@ -105,7 +102,7 @@ def predict(self, inp_data: np.ndarray) -> np.ndarray:
105102
-------
106103
Array containing the predicted results
107104
"""
108-
return self.model.predict(inp_data)
105+
return self.inverse_transform(self.model.predict(inp_data))
109106

110107
def inverse_transform(self, arr: np.ndarray) -> np.ndarray:
111108
"""Unscale ``arr`` using the fitted scaler.

swan/modeller/torch_modeller.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,9 @@ def split_data(self, frac: Tuple[float, float], batch_size: int):
121121
"""
122122
# create the dataloader
123123
indices_train, indices_validate = self.data.create_data_loader(frac=frac, batch_size=batch_size)
124-
125-
# Store the smiles used for training and validation
126-
self.state.store_array("smiles_train", self.smiles[indices_train], dtype="str")
127-
self.state.store_array("smiles_validate", self.smiles[indices_validate], dtype="str")
124+
self.labels_trainset = self.data.labels[indices_train]
125+
self.labels_validset = self.data.labels[indices_validate]
126+
self.store_trainset_in_state(np.concatenate((indices_train, indices_validate)), len(indices_validate), store_features=False)
128127

129128
def train_model(self,
130129
nepoch: int,

swan/state/state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Module to interact with HDF5."""
22

33
from pathlib import Path
4-
from typing import Any, Optional
4+
from typing import Any, List, Optional, Union
55

66
import h5py
77
import numpy as np
@@ -22,7 +22,7 @@ def __init__(self, path_hdf5: Optional[PathLike] = None, replace_state: bool = F
2222
if not self.path.exists():
2323
self.path.touch()
2424

25-
def has_data(self, data: ArrayLike) -> bool:
25+
def has_data(self, data: Union[str, List[str]]) -> bool:
2626
"""Search if the node exists in the HDF5 file.
2727
2828
Parameters

tests/test_mpnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def setUp(self):
1818
self.data = PATH_TEST / "thousand.csv"
1919
self.data = TorchGeometricGraphData(self.data, properties=["Hardness (eta)"])
2020
self.net = MPNN()
21-
self.modeller = TorchModeller(self.net, self.data)
21+
self.modeller = TorchModeller(self.net, self.data, replace_state=True)
2222

2323
def test_train_data_mpnn(self):
2424

tests/test_scikit_models.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,45 @@
77

88
from .utils_test import PATH_TEST
99

10-
DATA = FingerprintsData(PATH_TEST / "thousand.csv", properties=["Hardness (eta)"], sanitize=False)
11-
DATA.scale_labels()
12-
1310

1411
def run_test(model: str, **kwargs):
1512
"""Run the training and validation step for the given model."""
16-
modeller = SKModeller(model, DATA)
13+
data = FingerprintsData(PATH_TEST / "thousand.csv", properties=["Hardness (eta)"], sanitize=False)
14+
data.scale_labels()
15+
modeller = SKModeller(model, data)
1716
modeller.train_model()
1817
predicted, expected = modeller.validate_model()
1918
reg = stats.linregress(predicted.flatten(), expected.flatten())
2019
assert not np.isnan(reg.rvalue)
2120

2221

22+
def run_prediction(model: str):
23+
"""Check the prediction functionality."""
24+
data = FingerprintsData(PATH_TEST / "smiles.csv", sanitize=False)
25+
modeller = SKModeller(model, data)
26+
modeller.load_model("swan_skmodeller.pkl")
27+
modeller.data.load_scale()
28+
predicted = modeller.predict(data.fingerprints)
29+
assert not np.isnan(predicted).all()
30+
31+
2332
def test_decision_tree():
2433
"""Check the interface to the Decisiontree class."""
25-
run_test("decision_tree")
34+
model = "decision_tree"
35+
run_test(model)
36+
run_prediction(model)
2637

2738

2839
def test_svm():
2940
"""Check the interface to the support vector machine."""
30-
run_test("svm")
41+
model = "svm"
42+
run_test(model)
43+
run_prediction(model)
3144

3245

3346
def test_gaussian_process():
3447
"""Check the interface to the Gaussian process model."""
3548
kernel = ConstantKernel(constant_value=10)
36-
run_test("gaussian_process", kernel=kernel)
49+
model = "gaussian_process"
50+
run_test(model, kernel=kernel)
51+
run_prediction(model)

tests/test_se3_transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
torch.set_default_dtype(torch.float32)
1616

1717
CSV_FILE = PATH_TEST / "thousand.csv"
18-
DATA = DGLGraphData(CSV_FILE, properties=["Hardness (eta)"])
18+
DATA = DGLGraphData(CSV_FILE, properties=["Hardness (eta)"], optimize_molecule=True)
1919

2020

2121
def run_modeller(net: torch.nn.Module):
2222
"""Run a given model."""
23-
modeller = TorchModeller(net, DATA, use_cuda=False, replace_state=False)
23+
modeller = TorchModeller(net, DATA, use_cuda=False, replace_state=True)
2424

2525
modeller.data.scale_labels()
2626
modeller.train_model(nepoch=1, batch_size=64)

tests/test_state.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ def test_state(tmp_path: Path, capsys):
2525
out, _ = capsys.readouterr()
2626
assert "Available data" in out
2727

28+
assert not all(state.has_data(f"non_existing_{i}") for i in range(2))
29+
2830

2931
def test_state_unknown_key(tmp_path: Path):
3032
"""Check that an error is raised if there is no data."""
3133
path_hdf5 = tmp_path / "swan_state.h5"
32-
state = StateH5(path_hdf5)
34+
path_hdf5.touch()
35+
state = StateH5(path_hdf5, replace_state=True)
3336

3437
with pytest.raises(KeyError):
3538
state.retrieve_data("nonexisting property")
@@ -46,3 +49,5 @@ def store_smiles_in_state(tmp_path: Path):
4649
state.store_array("smiles", smiles, "str")
4750
data = [x.decode() for x in state.retrieve_data("smiles")]
4851
assert data == smiles.tolist()
52+
53+

0 commit comments

Comments
 (0)