Skip to content

Commit 8615ec8

Browse files
committed
add script to predict GP #90
1 parent 98ff8b6 commit 8615ec8

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

scripts/predict_gp.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/usr/bin/env python
2+
3+
import logging
4+
from pathlib import Path
5+
6+
import numpy as np
7+
import pandas as pd
8+
import torch
9+
10+
from swan.dataset import FingerprintsData, load_split_dataset
11+
from swan.modeller import GPModeller
12+
from swan.modeller.models import GaussianProcess
13+
from swan.utils.log_config import configure_logger
14+
15+
# Starting logger
16+
configure_logger(Path("."))
17+
LOGGER = logging.getLogger(__name__)
18+
19+
# Set float size default
20+
torch.set_default_dtype(torch.float32)
21+
22+
partition = load_split_dataset()
23+
features, labels = [torch.from_numpy(getattr(partition, x).astype(np.float32)) for x in ("features_trainset", "labels_trainset")]
24+
model = GaussianProcess(features, labels.flatten())
25+
data = FingerprintsData(Path("tests/files/smiles.csv"), properties=None, sanitize=False)
26+
27+
researcher = GPModeller(model, data, use_cuda=False, replace_state=False)
28+
# # If the labels are scaled you need to load the scaling functionality
29+
# researcher.data.load_scale()
30+
researcher.load_model("swan_chk.pt")
31+
32+
fingers = data.fingerprints
33+
print("shape fingers: ", fingers.shape)
34+
predicted = researcher.predict(fingers)
35+
df = pd.DataFrame(
36+
{"smiles": data.dataframe.smiles, "mean": predicted.mean, "lower": predicted.lower, "upper": predicted.upper})
37+
38+
df.to_csv("predicted_values.csv")

0 commit comments

Comments
 (0)