1
1
# Copyright (c) 2017-2019 Uber Technologies, Inc.
2
2
# SPDX-License-Identifier: Apache-2.0
3
3
4
- import numbers
5
- from typing import Iterator , NamedTuple , Optional , Tuple
4
+ from typing import Iterator , NamedTuple , Optional , Tuple , Union
6
5
7
6
import torch
8
7
from typing_extensions import Self
@@ -108,7 +107,7 @@ def __exit__(self, *args) -> None:
108
107
_DIM_ALLOCATOR .free (self .name , self .dim )
109
108
return super ().__exit__ (* args )
110
109
111
- def __iter__ (self ) -> Iterator [int ]:
110
+ def __iter__ (self ) -> Iterator [Union [ int , float ] ]:
112
111
if self ._vectorized is True or self .dim is not None :
113
112
raise ValueError (
114
113
"cannot use plate {} as both vectorized and non-vectorized"
@@ -121,7 +120,14 @@ def __iter__(self) -> Iterator[int]:
121
120
for i in self .indices :
122
121
self .next_context ()
123
122
with self :
124
- yield i if isinstance (i , numbers .Number ) else i .item ()
123
+ if isinstance (i , (int , float )):
124
+ yield i
125
+ elif isinstance (i , torch .Tensor ):
126
+ yield i .item ()
127
+ else :
128
+ raise ValueError (
129
+ f"Expected int, float or torch.Tensor, but got { type (i )} "
130
+ )
125
131
126
132
def _reset (self ) -> None :
127
133
if self ._vectorized :
0 commit comments