Skip to content

Commit 24b7a7f

Browse files
author
Yerdos Ordabayev
authored
Fix IndepMessenger.__iter__ type annotation (#3406)
1 parent 2fdbff9 commit 24b7a7f

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

pyro/nn/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __get__(
138138
if name not in obj.__dict__["_pyro_params"]:
139139
init_value, constraint, event_dim = self
140140
# bind method's self arg
141-
init_value = functools.partial(init_value, obj) # type: ignore[arg-type,misc,operator]
141+
init_value = functools.partial(init_value, obj) # type: ignore[arg-type,call-arg,misc,operator]
142142
setattr(obj, name, PyroParam(init_value, constraint, event_dim))
143143
value: PyroParam = obj.__getattr__(name)
144144
return value

pyro/poutine/indep_messenger.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# Copyright (c) 2017-2019 Uber Technologies, Inc.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import numbers
5-
from typing import Iterator, NamedTuple, Optional, Tuple
4+
from typing import Iterator, NamedTuple, Optional, Tuple, Union
65

76
import torch
87
from typing_extensions import Self
@@ -108,7 +107,7 @@ def __exit__(self, *args) -> None:
108107
_DIM_ALLOCATOR.free(self.name, self.dim)
109108
return super().__exit__(*args)
110109

111-
def __iter__(self) -> Iterator[int]:
110+
def __iter__(self) -> Iterator[Union[int, float]]:
112111
if self._vectorized is True or self.dim is not None:
113112
raise ValueError(
114113
"cannot use plate {} as both vectorized and non-vectorized"
@@ -121,7 +120,14 @@ def __iter__(self) -> Iterator[int]:
121120
for i in self.indices:
122121
self.next_context()
123122
with self:
124-
yield i if isinstance(i, numbers.Number) else i.item()
123+
if isinstance(i, (int, float)):
124+
yield i
125+
elif isinstance(i, torch.Tensor):
126+
yield i.item()
127+
else:
128+
raise ValueError(
129+
f"Expected int, float or torch.Tensor, but got {type(i)}"
130+
)
125131

126132
def _reset(self) -> None:
127133
if self._vectorized:

0 commit comments

Comments
 (0)