
Commit f555e90

fix a warning
1 parent: 7c7c69f

6 files changed, +12 -12 lines changed

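The warning being fixed is PyTorch's FutureWarning that torch.cuda.amp.autocast is deprecated in favor of the device-agnostic torch.amp.autocast, which takes the device type as its first argument. Below is a minimal sketch of the before/after pattern; the toy layer and tensor are illustrative, not from the repo.

# Minimal sketch of the migration this commit applies in every file (toy layer and
# tensor for illustration only). Recent PyTorch releases emit a FutureWarning for
# torch.cuda.amp.autocast and recommend the device-agnostic torch.amp.autocast.

import torch
from torch import nn

# old import, now deprecated:
#   from torch.cuda.amp import autocast
#   with autocast(enabled = False): ...

# new import, as used throughout this commit:
from torch.amp import autocast

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)

# the device type ('cuda') is now an explicit first argument
with autocast('cuda', enabled = False):
    out = layer(x)   # runs in full precision, even inside an outer autocast region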

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.16.3"
+version = "1.17.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 import torch.nn as nn
 from torch.nn import Module
 from torch import Tensor, int32
-from torch.cuda.amp import autocast
+from torch.amp import autocast

 from einops import rearrange, pack, unpack

@@ -159,7 +159,7 @@ def indices_to_codes(self, indices):

         return codes

-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     def forward(self, z):
         """
         einstein notation
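For context, the decorator above keeps FSQ's forward pass in full precision even when the surrounding model runs under mixed precision; only its signature changed here. A sketch of that decorator pattern on a made-up module (ToyQuantizer is illustrative, not repo code):

# Sketch of the decorator form above: disable autocast for one module's forward so
# precision-sensitive rounding stays in float32. ToyQuantizer is a made-up stand-in
# for FSQ, not code from this repository.

import torch
from torch import nn
from torch.amp import autocast

class ToyQuantizer(nn.Module):
    @autocast('cuda', enabled = False)   # same decorator form as FSQ.forward above
    def forward(self, z):
        z = z.float()                    # upcast inputs that were produced under autocast
        return z.round()                 # rounding happens in full precision

quantizer = ToyQuantizer()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with autocast('cuda', dtype = torch.float16, enabled = torch.cuda.is_available()):
    z = torch.randn(2, 8, device = device)
    out = quantizer(z)                   # executes outside the fp16 region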

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 from torch import nn, einsum
 import torch.nn.functional as F
 from torch.nn import Module
-from torch.cuda.amp import autocast
+from torch.amp import autocast

 from einops import rearrange, reduce, pack, unpack

@@ -293,7 +293,7 @@ def forward(

         force_f32 = self.force_quantization_f32

-        quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext
+        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

         with quantization_context():
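The partial/nullcontext line above builds the float32-forcing context lazily and only when force_quantization_f32 is set; the device type now rides along through partial. A small standalone sketch of that selection pattern (the flag and the rounding step below are illustrative):

# Sketch of the conditional-context pattern above: pick between an autocast-disabling
# context and a no-op, based on a flag. force_f32 and the rounding step are illustrative.

from contextlib import nullcontext
from functools import partial

import torch
from torch.amp import autocast

force_f32 = True

# same construction as the commit, with 'cuda' passed through partial
quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

with quantization_context():
    x = torch.randn(2, 8)
    codes = x.round()     # placeholder for the actual LFQ quantization step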

vector_quantize_pytorch/residual_fsq.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.amp import autocast

 from vector_quantize_pytorch.finite_scalar_quantization import FSQ

@@ -167,7 +167,7 @@ def forward(

         # go through the layers

-        with autocast(enabled = False):
+        with autocast('cuda', enabled = False):
             for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self.scales)):

                 if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
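Here the whole residual-quantization loop runs inside the autocast-disabling context; again only the 'cuda' argument is new. A toy sketch of that loop shape, where identity layers and fixed scales stand in for the repo's quantizers:

# Sketch of the with-statement form above: keep an entire residual quantization loop
# in float32. The identity layers and scales are placeholders, not ResidualFSQ code.

import torch
from torch import nn
from torch.amp import autocast

layers = nn.ModuleList([nn.Identity() for _ in range(3)])
scales = [1.0, 0.5, 0.25]

residual = torch.randn(2, 8)
quantized_out = torch.zeros_like(residual)

with autocast('cuda', enabled = False):        # the commit only adds the 'cuda' argument here
    for quantizer_index, (layer, scale) in enumerate(zip(layers, scales)):
        quantized = layer(residual) * scale    # stand-in for each quantizer's output
        residual = residual - quantized
        quantized_out = quantized_out + quantized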

vector_quantize_pytorch/residual_lfq.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.amp import autocast

 from vector_quantize_pytorch.lookup_free_quantization import LFQ

@@ -156,7 +156,7 @@ def forward(

         # go through the layers

-        with autocast(enabled = False):
+        with autocast('cuda', enabled = False):
             for quantizer_index, layer in enumerate(self.layers):

                 if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 import torch.nn.functional as F
 import torch.distributed as distributed
 from torch.optim import Optimizer
-from torch.cuda.amp import autocast
+from torch.amp import autocast

 import einx
 from einops import rearrange, repeat, reduce, pack, unpack
@@ -458,7 +458,7 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)

-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     def forward(
         self,
         x,
@@ -671,7 +671,7 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)

-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     def forward(
         self,
         x,
