|
20 | 20 |
|
21 | 21 | @dataclass
|
22 | 22 | class SGIRMolecularPredictor(GREAMolecularPredictor):
|
23 |
| - """This predictor trains the GREA model based on pseudo-labeling and data augmentation. |
24 |
| - Paper: Semi-Supervised Graph Imbalanced Regression (https://dl.acm.org/doi/10.1145/3580305.3599497) |
25 |
| - Reference Code: https://github.com/liugangcode/SGIR |
26 |
| -
|
27 |
| - Parameters |
28 |
| - ---------- |
29 |
| - num_anchor : int, default=10 |
30 |
| - Number of anchor points used to split the label space during pseudo-labeling |
31 |
| - warmup_epoch : int, default=20 |
32 |
| - Number of epochs to train before starting pseudo-labeling and data augmentation |
33 |
| - labeling_interval : int, default=5 |
34 |
| - Interval (in epochs) between pseudo-labeling steps |
35 |
| - augmentation_interval : int, default=5 |
36 |
| - Interval (in epochs) between data augmentation steps |
37 |
| - top_quantile : float, default=0.1 |
38 |
| - Quantile threshold for selecting high confidence predictions during pseudo-labeling |
39 |
| - label_logscale : bool, default=False |
40 |
| - Whether to use log scale for the label space during pseudo-labeling and data augmentation |
41 |
| - lw_aug : float, default=1 |
42 |
| - Weight for the data augmentation loss |
| 23 | + """ |
| 24 | + This predictor trains the GREA model based on pseudo-labeling and data augmentation. |
| 25 | +
|
| 26 | + Paper: `Semi-Supervised Graph Imbalanced Regression <https://dl.acm.org/doi/10.1145/3580305.3599497>`_ |
| 27 | + Reference Code: `SGIR GitHub <https://github.com/liugangcode/SGIR>`_ |
| 28 | +
|
| 29 | + :param num_anchor: Number of anchor points used to split the label space during pseudo-labeling |
| 30 | + :type num_anchor: int, default=10 |
| 31 | + :param warmup_epoch: Number of epochs to train before starting pseudo-labeling and data augmentation |
| 32 | + :type warmup_epoch: int, default=20 |
| 33 | + :param labeling_interval: Interval (in epochs) between pseudo-labeling steps |
| 34 | + :type labeling_interval: int, default=5 |
| 35 | + :param augmentation_interval: Interval (in epochs) between data augmentation steps |
| 36 | + :type augmentation_interval: int, default=5 |
| 37 | + :param top_quantile: Quantile threshold for selecting high confidence predictions during pseudo-labeling |
| 38 | + :type top_quantile: float, default=0.1 |
| 39 | + :param label_logscale: Whether to use log scale for the label space during pseudo-labeling and data augmentation |
| 40 | + :type label_logscale: bool, default=False |
| 41 | + :param lw_aug: Weight for the data augmentation loss |
| 42 | + :type lw_aug: float, default=1 |
43 | 43 | """
|
44 | 44 | # SGIR-specific parameters
|
45 | 45 | num_anchor: int = 10
|
|
0 commit comments