Skip to content

Commit 3710d2f

Browse files
NeelKondapallipre-commit-ci[bot]akihironitta
authored
Use weights_only=True from PyTorch 2.4 (#423)
Fixes #422 by adding the `weights_only = True` argument to `torch.load` in the file `io.py`. This protects agains the arbitrary data warning. The types `stype` and `StatType` were added to the safe globals list. By: Neel Kondapalli (neel2h06@gmail.com) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
1 parent 59994ec commit 3710d2f

File tree

4 files changed

+25
-4
lines changed

4 files changed

+25
-4
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
### Changed
1313

14+
- Set `weights_only=True` in `torch_frame.load` from PyTorch 2.4 ([#423](https://github.com/pyg-team/pytorch-frame/pull/423))
15+
1416
### Deprecated
1517

1618
### Removed

torch_frame/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,27 @@
1212
embedding,
1313
)
1414
from .data import TensorFrame
15-
from .typing import TaskType, Metric, DataFrame, NAStrategy
15+
from .typing import (
16+
TaskType,
17+
Metric,
18+
DataFrame,
19+
NAStrategy,
20+
WITH_PT24,
21+
)
1622
from torch_frame.utils import save, load, cat # noqa
1723
import torch_frame.data # noqa
1824
import torch_frame.datasets # noqa
1925
import torch_frame.nn # noqa
2026
import torch_frame.gbdt # noqa
2127

28+
if WITH_PT24:
29+
import torch
30+
31+
torch.serialization.add_safe_globals([
32+
stype,
33+
torch_frame.data.stats.StatType,
34+
])
35+
2236
__version__ = '0.2.3'
2337

2438
__all__ = [

torch_frame/typing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
from typing import Dict, List, Mapping, Union
55

66
import pandas as pd
7+
import torch
78
from torch import Tensor
89

910
from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
1011
from torch_frame.data.multi_nested_tensor import MultiNestedTensor
1112

13+
WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2
14+
WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
15+
1216

1317
class Metric(Enum):
1418
r"""The metric.

torch_frame/utils/io.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from torch_frame.data.multi_tensor import _MultiTensor
1515
from torch_frame.data.stats import StatType
16-
from torch_frame.typing import TensorData
16+
from torch_frame.typing import WITH_PT24, TensorData
1717

1818

1919
def serialize_feat_dict(
@@ -80,7 +80,8 @@ def save(tensor_frame: TensorFrame,
8080

8181

8282
def load(
83-
path: str, device: torch.device | None = None
83+
path: str,
84+
device: torch.device | None = None,
8485
) -> tuple[TensorFrame, dict[str, dict[StatType, Any]] | None]:
8586
r"""Load saved :class:`TensorFrame` object and optional :obj:`col_stats`
8687
from a specified path.
@@ -95,7 +96,7 @@ def load(
9596
tuple: A tuple of loaded :class:`TensorFrame` object and
9697
optional :obj:`col_stats`.
9798
"""
98-
tf_dict, col_stats = torch.load(path)
99+
tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24)
99100
tf_dict['feat_dict'] = deserialize_feat_dict(
100101
tf_dict.pop('feat_serialized_dict'))
101102
tensor_frame = TensorFrame(**tf_dict)

0 commit comments

Comments
 (0)