Skip to content

Commit 2eadef7

Browse files
committed
update dataset load output and add zinc250k
1 parent 9675959 commit 2eadef7

File tree

7 files changed

+166
-87
lines changed

7 files changed

+166
-87
lines changed

README.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ See the [List of Supported Models](#list-of-supported-models) section for all av
7070
> More examples can be found in the `examples` and `tests` folders.
7171
7272
`torch-molecule` supports applications across a broad range of domains, from chemistry and biology to materials science. To get started, you can load prepared datasets from `torch_molecule.datasets` (updated after v0.1.3):
73-
7473
| Dataset | Description | Function |
7574
|---------|-------------|----------|
7675
| qm9 | Quantum chemical properties (DFT level) | `load_qm9` |
@@ -79,23 +78,30 @@ See the [List of Supported Models](#list-of-supported-models) section for all av
7978
| toxcast | Toxicity of chemical compounds | `load_toxcast` |
8079
| admet | Chemical absorption, distribution, metabolism, excretion, and toxicity | `load_admet` |
8180
| gasperm | Six gas permeability properties for polymeric materials | `load_gasperm` |
82-
81+
| zinc250k | A commonly used subset of the ZINC dataset; it has no labels and can be used for unconditional generation or virtual screening | `load_zinc250k` |
8382

8483
```python
8584
from torch_molecule.datasets import load_qm9
8685

8786
# local_dir is the local path where the dataset will be saved
88-
smiles_list, property_np_array = load_qm9(local_dir='torchmol_data')
87+
molecular_data = load_qm9(local_dir='torchmol_data')
88+
smiles_list, property_np_array = molecular_data.data, molecular_data.target
8989

9090
# len(smiles_list): 133885
9191
# Property array shape: (133885, 1)
9292

9393
# load_qm9 returns the target "gap" by default, but you can adjust it by passing new target_cols
9494
target_cols = ['homo', 'lumo', 'gap']
95-
smiles_list, property_np_array = load_qm9(local_dir='torchmol_data', target_cols=target_cols)
95+
molecular_data = load_qm9(local_dir='torchmol_data', target_cols=target_cols)
96+
smiles_list, property_np_array = molecular_data.data, molecular_data.target
97+
98+
# the target could be None if loading an unlabeled dataset
99+
molecular_data = load_zinc250k(local_dir='torchmol_data')
100+
smiles_list = molecular_data.data
101+
assert molecular_data.target is None
96102
```
97103

98-
(We welcome your suggestions and contributions on your datasets!)
104+
(We are actively adding more datasets, and we welcome your suggestions and dataset contributions!)
99105

100106
### Fit a Model
101107

tests/datasets/gasperm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ def test_gasperm_download_and_cleanup():
1515
print("-" * 40)
1616

1717
# Test with default target columns
18-
smiles_list, property_numpy = load_gasperm()
18+
molecular_dataset = load_gasperm()
19+
smiles_list = molecular_dataset.data
20+
property_numpy = molecular_dataset.target
1921

2022
# Print results
2123
print(f"\nResults:")
@@ -52,7 +54,9 @@ def test_gasperm_download_and_cleanup():
5254
print("-" * 40)
5355

5456
custom_targets = ["CH4", "CO2"]
55-
smiles_list2, property_numpy2 = load_gasperm(target_cols=custom_targets)
57+
molecular_dataset2 = load_gasperm(target_cols=custom_targets)
58+
smiles_list2 = molecular_dataset2.data
59+
property_numpy2 = molecular_dataset2.target
5660

5761
print(f"Custom target results:")
5862
print(f"- Target columns: {custom_targets}")
@@ -65,7 +69,9 @@ def test_gasperm_download_and_cleanup():
6569
print("-" * 40)
6670

6771
single_target = ["H2"]
68-
smiles_list3, property_numpy3 = load_gasperm(target_cols=single_target)
72+
molecular_dataset3 = load_gasperm(target_cols=single_target)
73+
smiles_list3 = molecular_dataset3.data
74+
property_numpy3 = molecular_dataset3.target
6975

7076
print(f"Single target results:")
7177
print(f"- Target columns: {single_target}")
@@ -80,7 +86,9 @@ def test_gasperm_download_and_cleanup():
8086

8187
try:
8288
invalid_targets = ["INVALID_GAS"]
83-
smiles_list4, property_numpy4 = load_gasperm(target_cols=invalid_targets)
89+
molecular_dataset4 = load_gasperm(target_cols=invalid_targets)
90+
smiles_list4 = molecular_dataset4.data
91+
property_numpy4 = molecular_dataset4.target
8492
print("ERROR: Should have raised ValueError for invalid target column")
8593
except ValueError as e:
8694
print(f"Successfully caught expected error: {e}")

tests/datasets/hf.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import tempfile
33
import shutil
4-
from torch_molecule.datasets import load_qm9, load_chembl2k, load_broad6k, load_toxcast, load_admet
4+
from torch_molecule.datasets import load_qm9, load_chembl2k, load_broad6k, load_toxcast, load_admet, load_zinc250k
55
import numpy as np
66
import csv
77
import gzip
@@ -17,6 +17,8 @@ def load_dataset(dataset_name="qm9"):
1717
return load_toxcast
1818
elif dataset_name == "admet":
1919
return load_admet
20+
elif dataset_name == "zinc250k":
21+
return load_zinc250k
2022
else:
2123
raise ValueError(f"Dataset {dataset_name} not found")
2224

@@ -42,57 +44,70 @@ def test_download_and_cleanup(dataset_name="qm9"):
4244
print("-" * 40)
4345

4446
# Test with default target columns
45-
smiles_list, property_numpy, local_data_path = load_func(
47+
result = load_func(
4648
local_dir=test_csv_path,
4749
return_local_data_path=True,
4850
)
51+
molecular_dataset, local_data_path = result
4952

5053
# Print results
5154
print(f"\nResults:")
52-
print(f"- Number of molecules: {len(smiles_list)}")
53-
print(f"- Property array shape: {property_numpy.shape}")
55+
print(f"- Number of molecules: {len(molecular_dataset.data)}")
56+
print(f"- Property array shape: {molecular_dataset.target.shape if molecular_dataset.target is not None else 'None'}")
5457
print(f"- File exists: {os.path.exists(local_data_path)}")
5558
print(f"- File size: {os.path.getsize(local_data_path) if os.path.exists(local_data_path) else 0} bytes")
5659

5760
print(f"\nFirst 5 SMILES:")
58-
for i, smiles in enumerate(smiles_list[:5]):
61+
for i, smiles in enumerate(molecular_dataset.data[:5]):
5962
print(f" {i+1}. {smiles}")
6063

6164
print(f"\nFirst 5 property values (gap):")
62-
for i, prop in enumerate(property_numpy[:5]):
63-
print(f" {i+1}. {prop[0]:.6f}")
65+
if molecular_dataset.target is not None:
66+
for i, prop in enumerate(molecular_dataset.target[:5]):
67+
print(f" {i+1}. {prop[0]:.6f}")
68+
else:
69+
print(" No property values available (target is None)")
6470

6571
print(f"\nProperty statistics:")
6672
# Calculate statistics excluding NaN values
67-
non_null_mask = ~np.isnan(property_numpy)
68-
non_null_values = property_numpy[non_null_mask]
69-
70-
print(f" Total values: {property_numpy.size}")
71-
print(f" Non-null values: {non_null_values.size}")
72-
print(f" Null values: {property_numpy.size - non_null_values.size}")
73-
print(f" Non-null percentage: {(non_null_values.size / property_numpy.size * 100):.2f}%")
74-
75-
if non_null_values.size > 0:
76-
print(f" Min (non-null): {non_null_values.min():.6f}")
77-
print(f" Max (non-null): {non_null_values.max():.6f}")
78-
print(f" Mean (non-null): {non_null_values.mean():.6f}")
79-
print(f" Std (non-null): {non_null_values.std():.6f}")
73+
if molecular_dataset.target is not None:
74+
non_null_mask = ~np.isnan(molecular_dataset.target)
75+
non_null_values = molecular_dataset.target[non_null_mask]
76+
77+
print(f" Total values: {molecular_dataset.target.size}")
78+
print(f" Non-null values: {non_null_values.size}")
79+
print(f" Null values: {molecular_dataset.target.size - non_null_values.size}")
80+
print(f" Non-null percentage: {(non_null_values.size / molecular_dataset.target.size * 100):.2f}%")
81+
82+
if non_null_values.size > 0:
83+
print(f" Min (non-null): {non_null_values.min():.6f}")
84+
print(f" Max (non-null): {non_null_values.max():.6f}")
85+
print(f" Mean (non-null): {non_null_values.mean():.6f}")
86+
print(f" Std (non-null): {non_null_values.std():.6f}")
87+
else:
88+
print(" No non-null values found")
8089
else:
81-
print(" No non-null values found")
90+
print(" No property statistics available (target is None)")
8291

8392
# Test loading from existing file (should not download again)
8493
print(f"\n2. Testing loading from existing file")
8594
print("-" * 40)
8695

87-
smiles_list2, property_numpy2, local_data_path = load_func(
96+
result2 = load_func(
8897
local_dir=test_csv_path,
8998
return_local_data_path=True,
9099
)
100+
molecular_dataset2, local_data_path2 = result2
91101

92102
print(f"Second load results:")
93-
print(f"- Same number of molecules: {len(smiles_list2) == len(smiles_list)}")
94-
print(f"- Same property shape: {property_numpy2.shape == property_numpy.shape}")
95-
print(f"- Local data path: {local_data_path}")
103+
print(f"- Same number of molecules: {len(molecular_dataset2.data) == len(molecular_dataset.data)}")
104+
if molecular_dataset.target is not None and molecular_dataset2.target is not None:
105+
print(f"- Same property shape: {molecular_dataset2.target.shape == molecular_dataset.target.shape}")
106+
elif molecular_dataset.target is None and molecular_dataset2.target is None:
107+
print(f"- Same property shape: True (both are None)")
108+
else:
109+
print(f"- Same property shape: False (different None status)")
110+
print(f"- Local data path: {local_data_path2}")
96111

97112
# Test with multiple target columns (if available)
98113
print(f"\n3. Testing with multiple target columns")
@@ -132,8 +147,9 @@ def test_download_and_cleanup(dataset_name="qm9"):
132147

133148

134149
if __name__ == "__main__":
135-
test_download_and_cleanup(dataset_name="qm9")
136-
test_download_and_cleanup(dataset_name="chembl2k")
137-
test_download_and_cleanup(dataset_name="broad6k")
138-
test_download_and_cleanup(dataset_name="toxcast")
139-
test_download_and_cleanup(dataset_name="admet")
150+
# test_download_and_cleanup(dataset_name="qm9")
151+
# test_download_and_cleanup(dataset_name="chembl2k")
152+
# test_download_and_cleanup(dataset_name="broad6k")
153+
# test_download_and_cleanup(dataset_name="toxcast")
154+
# test_download_and_cleanup(dataset_name="admet")
155+
test_download_and_cleanup(dataset_name="zinc250k")

torch_molecule/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .load_hf_dataset import load_qm9, load_chembl2k, load_broad6k, load_toxcast, load_admet
1+
from .load_hf_dataset import load_qm9, load_chembl2k, load_broad6k, load_toxcast, load_admet, load_zinc250k
22
from .load_local_csv import load_gasperm
33

44
__all__ = [
@@ -8,4 +8,5 @@
88
"load_toxcast",
99
"load_admet",
1010
"load_gasperm",
11+
"load_zinc250k",
1112
]

torch_molecule/datasets/constant.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
from dataclasses import dataclass
2+
from typing import List
3+
import numpy as np
4+
5+
@dataclass
6+
class SMILESDataset:
7+
"""
8+
Data class for storing molecular SMILES dataset with input and target data.
9+
10+
Attributes:
11+
data (List[str]): Input data (e.g., list of SMILES strings)
12+
target (np.ndarray | None): Target property values as 2D numpy array (rows=molecules, cols=targets) or None
13+
"""
14+
data: List[str]
15+
target: np.ndarray | None
16+
17+
118
TOXCAST_TASKS = [
219
'ACEA_T47D_80hr_Negative', 'ACEA_T47D_80hr_Positive',
320
'APR_HepG2_CellCycleArrest_24h_dn', 'APR_HepG2_CellCycleArrest_24h_up',

0 commit comments

Comments
 (0)