
Commit fcd09fe

Fix issue with new caching mechanism in transformers and bump versions (#313)
*Issue #, if available:* Fixes #310 and closes #302

*Description of changes:* This PR fixes an issue with the new caching mechanism for T5 introduced in `transformers==4.54`. [Prior versions set](https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/models/t5/modeling_t5.py#L1328) `encoder_config.is_encoder_decoder = False` when initializing the encoder and decoder, and, following transformers, Chronos-Bolt was initialized the same way. In v4.54, however, that line [has been removed](https://github.com/huggingface/transformers/blob/3fd456b200ba434e567412cc4517309482653f60/src/transformers/models/t5/modeling_t5.py#L1301) and [new logic has been added](https://github.com/huggingface/transformers/blob/3fd456b200ba434e567412cc4517309482653f60/src/transformers/models/t5/modeling_t5.py#L494) that relies on `is_encoder_decoder` [being True](https://github.com/huggingface/transformers/blob/3fd456b200ba434e567412cc4517309482653f60/src/transformers/models/t5/modeling_t5.py#L1007). This breaks Chronos-Bolt as described in #310.

This PR removes `is_encoder_decoder = False` for both the encoder and the decoder, which fixes the issue. Re-running our mini eval in CI produced the same results for v4.54 and v4.48 (our current lower bound). This PR also bumps package versions.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Parent: 6a9c8da
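For context, here is a minimal sketch of the encoder/decoder initialization pattern after this fix. The config values and variable names are illustrative, not the actual Chronos-Bolt configuration; only the `T5Config`/`T5Stack` usage mirrors the code touched by this PR:

```python
import copy

import torch.nn as nn
from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Stack

# Illustrative (tiny) config; real Chronos-Bolt checkpoints use larger values.
config = T5Config(
    vocab_size=128, d_model=64, d_ff=128, num_layers=2, num_decoder_layers=2, num_heads=4
)
shared = nn.Embedding(config.vocab_size, config.d_model)

encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
# `is_encoder_decoder` is deliberately left at its default (True):
# transformers>=4.54 relies on it when setting up the new cache.
encoder = T5Stack(encoder_config, shared)

decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.num_layers = config.num_decoder_layers
decoder = T5Stack(decoder_config, shared)

assert encoder_config.is_encoder_decoder and decoder_config.is_encoder_decoder
```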

3 files changed: 42 additions & 24 deletions

pyproject.toml

Lines changed: 34 additions & 16 deletions
```diff
@@ -1,25 +1,21 @@
 [project]
 name = "chronos-forecasting"
-version = "1.5.2"
+version = "1.5.3"
 authors = [
-    { name="Abdul Fatir Ansari", email="ansarnd@amazon.com" },
-    { name="Lorenzo Stella", email="stellalo@amazon.com" },
-    { name="Caner Turkmen", email="atturkm@amazon.com" },
+  { name = "Abdul Fatir Ansari", email = "ansarnd@amazon.com" },
+  { name = "Lorenzo Stella", email = "stellalo@amazon.com" },
+  { name = "Caner Turkmen", email = "atturkm@amazon.com" },
 ]
 description = "Chronos: Pretrained models for time series forecasting"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
-dependencies = [
-    "torch>=2.0,<3",  # package was tested on 2.2
-    "transformers>=4.48,<5",
-    "accelerate>=0.32,<2",
-]
+dependencies = ["torch>=2.0,<3", "transformers>=4.48,<5", "accelerate>=0.32,<2"]
 classifiers = [
-    "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: Apache Software License",
-    "Operating System :: OS Independent",
-    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+  "Programming Language :: Python :: 3",
+  "License :: OSI Approved :: Apache Software License",
+  "Operating System :: OS Independent",
+  "Topic :: Scientific/Engineering :: Artificial Intelligence",
 ]

 [build-system]
@@ -30,10 +26,24 @@ build-backend = "hatchling.build"
 packages = ["src/chronos"]

 [project.optional-dependencies]
-test = ["pytest~=8.0", "numpy~=1.21"]
+test = ["pytest~=8.0", "numpy>=1.21,<3"]
 typecheck = ["mypy~=1.9"]
-training = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer", "typer-config", "joblib", "scikit-learn", "tensorboard"]
-evaluation = ["gluonts[pro]~=0.15", "numpy~=1.21", "datasets~=2.18", "typer"]
+training = [
+  "gluonts[pro]~=0.15",
+  "numpy>=1.21,<3",
+  "datasets>=2.18,<4",
+  "typer",
+  "typer-config",
+  "joblib",
+  "scikit-learn",
+  "tensorboard",
+]
+evaluation = [
+  "gluonts[pro]~=0.15",
+  "numpy>=1.21,<3",
+  "datasets>=2.18,<4",
+  "typer",
+]

 [project.urls]
 Homepage = "https://github.com/amazon-science/chronos-forecasting"
@@ -42,3 +52,11 @@ Paper = "https://arxiv.org/abs/2403.07815"

 [tool.mypy]
 ignore_missing_imports = true
+
+[tool.ruff]
+line-length = 88
+lint.ignore = [
+  "E501", # Line too long
+  "E731", # Do not assign a `lambda` expression, use a `def`
+  "E722", # Do not use bare `except`
+]
```
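Note the dependency-pinning change above: `~=1.21` is a compatible-release pin (equivalent to `>=1.21,<2`), so it excluded NumPy 2.x, whereas the new `>=1.21,<3` range admits it; the `datasets` bound is relaxed the same way. A quick check using the `packaging` library (not a project dependency; used here purely for illustration):

```python
from packaging.specifiers import SpecifierSet

old_pin = SpecifierSet("~=1.21")     # compatible release: >=1.21, ==1.*, excludes NumPy 2.x
new_pin = SpecifierSet(">=1.21,<3")  # explicit range: NumPy 2.x allowed, 3.x is not

print("2.1.0" in old_pin)  # False
print("2.1.0" in new_pin)  # True
print("3.0.0" in new_pin)  # False
```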

src/chronos/chronos.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -302,10 +302,12 @@ def encode(
             A tensor of encoder embeddings with shape
             (batch_size, sequence_length, d_model).
         """
-        assert (
-            self.config.model_type == "seq2seq"
-        ), "Encoder embeddings are only supported for encoder-decoder models"
-        assert hasattr(self.model, "encoder")
+        assert self.config.model_type == "seq2seq", (
+            "Encoder embeddings are only supported for encoder-decoder models"
+        )
+        assert hasattr(self.model, "encoder") and isinstance(
+            self.model.encoder, nn.Module
+        )

         return self.model.encoder(
             input_ids=input_ids, attention_mask=attention_mask
@@ -346,7 +348,7 @@ def forward(
         if top_p is None:
             top_p = self.config.top_p

-        assert hasattr(self.model, "generate")
+        assert callable(getattr(self.model, "generate", None))

         preds = self.model.generate(
             input_ids=input_ids,
@@ -362,7 +364,7 @@ def forward(
                 top_k=top_k,
                 top_p=top_p,
             ),
-        )
+        )  # type: ignore

         if self.config.model_type == "seq2seq":
             preds = preds[..., 1:]  # remove the decoder start token
```
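One small detail above: the old `hasattr(self.model, "generate")` check passes whenever the attribute exists, even if it is not callable, while `callable(getattr(..., None))` also verifies that it can be invoked. A minimal illustration with a hypothetical stub class:

```python
class Stub:
    generate = None  # attribute exists, but is not callable

stub = Stub()
print(hasattr(stub, "generate"))                  # True  -> the old assert would pass
print(callable(getattr(stub, "generate", None)))  # False -> the new assert catches it
```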

src/chronos/chronos_bolt.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -179,7 +179,6 @@ def __init__(self, config: T5Config):
         encoder_config = copy.deepcopy(config)
         encoder_config.is_decoder = False
         encoder_config.use_cache = False
-        encoder_config.is_encoder_decoder = False
         self.encoder = T5Stack(encoder_config, self.shared)

         self._init_decoder(config)
@@ -381,7 +380,6 @@ def forward(
     def _init_decoder(self, config):
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
-        decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config, self.shared)
```
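With the flag removal above, loading Chronos-Bolt under `transformers>=4.54` should work again. A hypothetical smoke test along the lines of the project README (the checkpoint name and expected output shape are assumptions, not part of this commit):

```python
import torch
from chronos import BaseChronosPipeline

# Under transformers>=4.54, this load previously failed as described in #310.
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-small", device_map="cpu")

context = torch.randn(1, 512)  # one dummy series of length 512
quantiles, mean = pipeline.predict_quantiles(
    context=context, prediction_length=24, quantile_levels=[0.1, 0.5, 0.9]
)
print(quantiles.shape)  # expected: torch.Size([1, 24, 3])
```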
