Skip to content

Commit 3951eaf

Browse files
committed
Update: pbar to global. Fix: dataclass to class. Fixes #16
1 parent ddb07ad commit 3951eaf

39 files changed

+2625
-1328
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,6 @@ scripts/*
167167
examples/temp*
168168
examples/prediction/predict_GREA_for_gas.py
169169
examples/prediction/train_GNN_for_any.py
170-
examples/prediction/train_GREA_for_gas.py
170+
examples/prediction/train_GREA_for_gas.py
171+
tools/
172+
**/*.bak

tests/generator/digress.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch_molecule import DigressMolecularGenerator
77
from torch_molecule.utils.search import ParameterType, ParameterSpec
88

9-
EPOCHS = 5
9+
EPOCHS = 100
1010
BATCH_SIZE = 32
1111

1212
def test_digress_generator():

tests/predictor/bfgnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def test_bfgnn_predictor():
2121
num_layer=3,
2222
hidden_size=128,
2323
batch_size=4,
24-
epochs=5, # Small number for testing
24+
epochs=100, # Small number for testing
25+
patience=100,
2526
verbose=True,
2627
l1_penalty=1e-3
2728
)

tests/predictor/lstm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def test_lstm_predictor():
2727
output_dim=5, # Output dimension matches number of tasks
2828
LSTMunits=60,
2929
batch_size=2,
30-
epochs=2,
30+
epochs=200,
31+
patience=200,
3132
device="cpu",
3233
verbose=True
3334
)

tests/predictor/smilestransformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def test_transformer_predictor():
2626
n_heads=4,
2727
num_layers=2,
2828
batch_size=2,
29-
epochs=2,
29+
epochs=200,
30+
patience=200,
3031
device="cpu",
3132
verbose=True,
3233
use_lr_scheduler=True,

torch_molecule/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.3"
1+
__version__ = "0.1.4"
22

33
"""
44
predictor module

torch_molecule/base/base.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,48 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass, field
3-
from typing import Optional, Dict, List, Type, Any, ClassVar, Union, Tuple, Callable, Literal
3+
from typing import Optional, Dict, List, Type, Any, Union, Tuple
44
import torch
55
import os
66
import numpy as np
77
from ..utils.checkpoint import LocalCheckpointManager, HuggingFaceCheckpointManager
88
from ..utils.checker import MolecularInputChecker
99

10-
@dataclass
1110
class BaseModel(ABC):
1211
"""Base class for molecular models with shared functionality.
1312
1413
This abstract class provides common methods and utilities for molecular models,
1514
including model initialization, saving/loading, and parameter management.
16-
"""
1715
18-
device: Optional[torch.device] = field(default=None)
19-
model_name: str = field(default="BaseModel")
20-
model_class: Optional[Type[torch.nn.Module]] = field(default=None, init=False) # used for model initialization
21-
model: Optional[torch.nn.Module] = field(default=None, init=False) # initialized model
22-
is_fitted_: bool = field(default=False, init=False)
23-
24-
def __post_init__(self):
25-
"""Initialize common device settings after instance creation.
16+
Parameters
17+
----------
18+
device : torch.device, optional
19+
Device to run the model on. If None, automatically selects CUDA if available,
20+
otherwise CPU.
21+
model_name : str, default="BaseModel"
22+
String identifier for the model name which can be specified by the user.
23+
24+
Attributes
25+
----------
26+
model_class : type or None
27+
The class of the model used to initialize the model instance.
28+
model : object or None
29+
The fitted model instance if the model has been trained, None otherwise.
30+
is_fitted_ : bool
31+
Whether the model has been fitted/trained. False by default.
32+
"""
33+
def __init__(self, device: Optional[torch.device] = None, model_name: str = "BaseModel"):
34+
self.device = device
35+
self.model_name = model_name # string of the model name which could be specified by the user
2636

27-
Sets the device to CUDA if available, otherwise CPU, when no device is specified.
28-
Converts string device specifications to torch.device objects.
29-
"""
3037
if self.device is None:
3138
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
3239
elif isinstance(self.device, str):
3340
self.device = torch.device(self.device)
3441

42+
self.is_fitted_ = False # whether the model is fitted
43+
self.model = None # the fitted model if not None
44+
self.model_class = None # the class of the model used to initialize the model
45+
3546
@abstractmethod
3647
def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
3748
"""Set up optimizers for model training.
@@ -78,7 +89,7 @@ def _get_model_params(self, checkpoint: Optional[Dict] = None) -> Dict[str, Any]
7889
pass
7990

8091
@staticmethod
81-
def _get_param_names(self) -> List[str]:
92+
def _get_param_names() -> List[str]:
8293
"""Get parameter names in the modeling class.
8394
8495
Returns
@@ -104,7 +115,7 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
104115
Dictionary of parameter names mapped to their values
105116
"""
106117
out = {}
107-
for key in self._get_param_names():
118+
for key in self.__class__._get_param_names():
108119
value = getattr(self, key)
109120
if deep and hasattr(value, "get_params"):
110121
deep_items = value.get_params().items()
@@ -392,5 +403,4 @@ def format_value(v):
392403
if len(repr_str) > N_CHAR_MAX:
393404
repr_str = "\n".join([repr_str[:N_CHAR_MAX//2], "...", repr_str[-N_CHAR_MAX//2:]])
394405

395-
return repr_str
396-
406+
return repr_str

torch_molecule/base/encoder.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
from dataclasses import dataclass, field
21
from abc import ABC, abstractmethod
3-
from typing import Optional, ClassVar, Union, List, Dict, Any, Tuple, Callable, Type, Literal
2+
from typing import Optional, Union, List, Literal
43

54
import torch
65
import numpy as np
76
from .base import BaseModel
87

9-
@dataclass
108
class BaseMolecularEncoder(BaseModel, ABC):
119
"""Base class for molecular representation learning."""
12-
13-
model_name: str = field(default="BaseMolecularEncoder")
14-
10+
def __init__(
11+
self,
12+
*,
13+
device: Optional[Union[torch.device, str]] = None,
14+
model_name: str = "BaseMolecularEncoder",
15+
):
16+
super().__init__(device=device, model_name=model_name)
17+
1518
@abstractmethod
1619
def encode(self, X: List[str], return_type: Literal["np", "pt"] = "pt") -> Union[np.ndarray, torch.Tensor]:
1720
pass

torch_molecule/base/generator.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
from dataclasses import dataclass, field
21
from abc import ABC, abstractmethod
3-
from typing import Optional, ClassVar, Union, List, Dict, Any, Tuple, Callable, Type, Literal
4-
2+
from typing import Optional, List, Union
53
import torch
64
import numpy as np
75
from .base import BaseModel
86

9-
@dataclass
107
class BaseMolecularGenerator(BaseModel, ABC):
118
"""Base class for molecular generation."""
12-
13-
model_name: str = field(default="BaseMolecularGenerator")
9+
def __init__(
10+
self,
11+
*,
12+
device: Optional[Union[torch.device, str]] = None,
13+
model_name: str = "BaseMolecularGenerator",
14+
):
15+
super().__init__(device=device, model_name=model_name)
1416

1517
@abstractmethod
1618
def fit(self, X: List[str], y: Optional[np.ndarray] = None) -> "BaseMolecularGenerator":
@@ -20,5 +22,4 @@ def fit(self, X: List[str], y: Optional[np.ndarray] = None) -> "BaseMolecularGen
2022
def generate(self, n_samples: int, **kwargs) -> List[str]:
2123
"""Generate molecular structures.
2224
"""
23-
pass
24-
25+
pass

torch_molecule/base/predictor.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,32 @@
99
root_mean_squared_error,
1010
r2_score,
1111
)
12-
from dataclasses import dataclass, field
1312
from abc import ABC, abstractmethod
14-
from typing import Optional, ClassVar, Union, List, Dict, Any, Tuple, Callable, Type
13+
from typing import Optional, Union, List, Tuple, Callable
1514
from ..base.base import BaseModel
1615

17-
@dataclass
1816
class BaseMolecularPredictor(BaseModel, ABC):
1917
"""Base class for molecular discovery estimators."""
20-
21-
model_name: str = field(default="BaseMolecularPredictor")
22-
num_task: int = field(default=0)
23-
task_type: str = field(default=None)
24-
DEFAULT_METRICS: ClassVar[Dict] = {
25-
"classification": {"default": ("roc_auc", roc_auc_score, True)},
26-
"regression": {"default": ("mae", mean_absolute_error, False)},
27-
}
18+
def __init__(
19+
self,
20+
*,
21+
device: Optional[Union[torch.device, str]] = None,
22+
model_name: str = "BaseMolecularPredictor",
23+
num_task: int = 0,
24+
task_type: Optional[str] = None,
25+
):
26+
super().__init__(device=device, model_name=model_name)
27+
self.num_task = num_task
28+
self.task_type = task_type
2829

29-
def __post_init__(self):
30-
super().__post_init__()
3130
if self.task_type not in ["classification", "regression"]:
3231
raise ValueError(f"Invalid task_type: {self.task_type}")
3332
if self.num_task <= 0:
3433
raise ValueError(f"num_task must be positive, got {self.num_task}")
3534

3635
@staticmethod
37-
def _get_param_names(self) -> List[str]:
38-
return super()._get_param_names() + ["num_task", "task_type"]
36+
def _get_param_names() -> List[str]:
37+
return BaseModel._get_param_names() + ["num_task", "task_type"]
3938

4039
@abstractmethod
4140
def autofit(self, X_train, y_train, X_val=None, y_val=None, search_parameters: Optional[dict] = None, n_trials: int = 10) -> "BaseMolecularPredictor":
@@ -59,10 +58,16 @@ def _setup_evaluation(
5958
evaluate_higher_better: Optional[bool],
6059
) -> None:
6160
if evaluate_criterion is None:
62-
default_metric = self.DEFAULT_METRICS[self.task_type]["default"]
63-
self.evaluate_name = default_metric[0]
64-
self.evaluate_criterion = default_metric[1]
65-
self.evaluate_higher_better = default_metric[2]
61+
if self.task_type == 'classification':
62+
self.evaluate_name = 'roc_auc'
63+
self.evaluate_criterion = roc_auc_score
64+
self.evaluate_higher_better = True
65+
elif self.task_type == 'regression':
66+
self.evaluate_name = 'mae'
67+
self.evaluate_criterion = mean_absolute_error
68+
self.evaluate_higher_better = False
69+
else:
70+
raise ValueError(f"The task type {self.task_type} does not have a default metric.")
6671
else:
6772
if isinstance(evaluate_criterion, str):
6873
metric_map = {

0 commit comments

Comments
 (0)