Commit 99cba96

Make mypy very happy
1 parent bf1ce18 commit 99cba96

2 files changed: +41 -35 lines changed

mcbackend/backends/clickhouse.py

Lines changed: 26 additions & 11 deletions
@@ -5,7 +5,18 @@
 import logging
 import time
 from datetime import datetime, timezone
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

 import clickhouse_driver
 import numpy
@@ -156,7 +167,7 @@ def __init__(
         self._client = client
         # The following attributes belong to the batched insert mechanism.
         # Inserting in batches is much faster than inserting single rows.
-        self._str_cols = set()
+        self._str_cols: Set[str] = set()
         self._insert_query: str = ""
         self._insert_queue: List[Dict[str, Any]] = []
         self._last_insert = time.time()
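
Note: without an annotation, mypy cannot infer an element type for the bare set() and typically reports a "Need type annotation" error for the attribute; declaring it as Set[str] resolves that. A minimal sketch of the pattern (the Example class is illustrative, only the attribute name comes from the diff):

from typing import Set

class Example:
    def __init__(self) -> None:
        # A bare `self._str_cols = set()` makes mypy ask for a type annotation,
        # because the empty literal gives it no element type to infer.
        self._str_cols: Set[str] = set()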
@@ -176,13 +187,16 @@ def append(
             self._insert_query = f"INSERT INTO {self.cid} (`_draw_idx`,`{names}`) VALUES"
             self._str_cols = {k for k, v in params.items() if "str" in numpy.asarray(v).dtype.name}

-        # Convert str ndarrays to lists
+        params_ins: Dict[str, Union[numpy.ndarray, int, float, List[str]]] = {
+            "_draw_idx": self._draw_idx,
+            **params,
+        }
+        # Convert str-dtyped ndarrays to lists
         for col in self._str_cols:
-            params[col] = params[col].tolist()
+            params_ins[col] = params[col].tolist()

         # Queue up for insertion
-        params["_draw_idx"] = self._draw_idx
-        self._insert_queue.append(params)
+        self._insert_queue.append(params_ins)
         self._draw_idx += 1

         if (
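
Note: if append receives params as a read-only Mapping (the exact signature is not shown in this hunk), writing back into it is rejected by mypy because Mapping has no __setitem__, and stuffing the integer _draw_idx into a dict of ndarrays also breaks the value type. Building a fresh params_ins dict with a Union value type avoids both complaints. A hedged illustration with an assumed signature and a hypothetical helper name:

from typing import Dict, List, Mapping, Union
import numpy

def queue_row(
    draw_idx: int, params: Mapping[str, numpy.ndarray]
) -> Dict[str, Union[numpy.ndarray, int, float, List[str]]]:
    # `params[col] = ...` would fail type checking: Mapping is read-only.
    row: Dict[str, Union[numpy.ndarray, int, float, List[str]]] = {"_draw_idx": draw_idx, **params}
    for col, value in params.items():
        if "str" in value.dtype.name:
            row[col] = value.tolist()  # plain lists are easier to insert than str-dtyped arrays
    return row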
@@ -242,13 +256,14 @@ def _get_rows(

         # Without draws return empty arrays of the correct shape/dtype
         if not draws:
-            if is_rigid(nshape):
-                return numpy.empty(shape=[0] + nshape, dtype=dtype)
+            if is_rigid(nshape) and nshape is not None:
+                return numpy.empty(shape=[0, *nshape], dtype=dtype)
             return numpy.array([], dtype=object)

         # The unpacking must also account for non-rigid shapes
         # and str-dtyped empty arrays default to fixed length 1 strings.
         # The [None] list is slower, but more flexible in this regard.
+        buffer: Union[numpy.ndarray, Sequence]
         if is_rigid(nshape) and dtype != "str":
             assert nshape is not None
             buffer = numpy.empty((draws, *nshape), dtype)
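
Note: mypy pins a variable to the type of its first assignment, so when one branch fills buffer with an ndarray and another with a [None] list, the second assignment is flagged as incompatible; declaring the Union up front keeps both branches valid. Roughly, in isolation (simplified, not the mcbackend code):

from typing import Sequence, Union
import numpy

def make_buffer(rigid: bool, draws: int) -> Union[numpy.ndarray, Sequence]:
    buffer: Union[numpy.ndarray, Sequence]
    if rigid:
        buffer = numpy.empty((draws, 3), dtype=float)
    else:
        # Without the declaration above, mypy would infer ndarray from the
        # first branch and reject this list assignment.
        buffer = [None] * draws
    return buffer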
@@ -292,7 +307,7 @@ def __init__(
         self,
         meta: RunMeta,
         *,
-        created_at: datetime = None,
+        created_at: Optional[datetime] = None,
         client_fn: Callable[[], clickhouse_driver.Client],
     ) -> None:
         self._client_fn = client_fn
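
Note: created_at: datetime = None relies on implicit-Optional handling, which recent mypy versions disable by default, producing an "Incompatible default for argument" error; spelling out Optional[datetime] fixes it, and the same change is applied to the ClickHouseBackend.__init__ arguments in the next hunk. A small reproduction (exact error wording may vary by mypy version):

from datetime import datetime
from typing import Optional

def bad(created_at: datetime = None) -> None:  # flagged: default None is not a datetime
    ...

def good(created_at: Optional[datetime] = None) -> None:
    ...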
@@ -331,8 +346,8 @@ class ClickHouseBackend(Backend):

     def __init__(
         self,
-        client: clickhouse_driver.Client = None,
-        client_fn: Callable[[], clickhouse_driver.Client] = None,
+        client: Optional[clickhouse_driver.Client] = None,
+        client_fn: Optional[Callable[[], clickhouse_driver.Client]] = None,
     ):
         """Create a ClickHouse backend around a database client.

mcbackend/core.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,20 @@
33
"""
44
import collections
55
import logging
6-
from typing import (
7-
TYPE_CHECKING,
8-
Dict,
9-
List,
10-
Mapping,
11-
Optional,
12-
Sequence,
13-
Sized,
14-
TypeVar,
15-
)
6+
from typing import Dict, List, Mapping, Optional, Sequence, Sized, TypeVar, Union, cast
167

178
import numpy
189

1910
from .meta import ChainMeta, RunMeta, Variable
2011
from .npproto.utils import ndarray_to_numpy
2112
from .utils import as_array_from_ragged
2213

23-
InferenceData = TypeVar("InferenceData")
2414
try:
25-
from arviz import from_dict
15+
from arviz import InferenceData, from_dict
2616

27-
if not TYPE_CHECKING:
28-
from arviz import InferenceData
2917
_HAS_ARVIZ = True
3018
except ModuleNotFoundError:
19+
InferenceData = TypeVar("InferenceData") # type: ignore
3120
_HAS_ARVIZ = False
3221

3322
Shape = Sequence[int]
@@ -262,20 +251,22 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
262251
warmup_sample_stats[svar.name].append(stats[tune])
263252
sample_stats[svar.name].append(stats[~tune])
264253

254+
w_pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_posterior)
255+
w_ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_sample_stats)
256+
pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], posterior)
257+
ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], sample_stats)
265258
if not equalize_chain_lengths:
266259
# Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically
267-
warmup_posterior = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()}
268-
warmup_sample_stats = {
269-
k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()
270-
}
271-
posterior = {k: as_array_from_ragged(v) for k, v in posterior.items()}
272-
sample_stats = {k: as_array_from_ragged(v) for k, v in sample_stats.items()}
260+
w_pst = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()}
261+
w_ss = {k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()}
262+
pst = {k: as_array_from_ragged(v) for k, v in posterior.items()}
263+
ss = {k: as_array_from_ragged(v) for k, v in sample_stats.items()}
273264

274265
idata = from_dict(
275-
warmup_posterior=warmup_posterior,
276-
warmup_sample_stats=warmup_sample_stats,
277-
posterior=posterior,
278-
sample_stats=sample_stats,
266+
warmup_posterior=w_pst,
267+
warmup_sample_stats=w_ss,
268+
posterior=pst,
269+
sample_stats=ss,
279270
coords=self.coords,
280271
dims=self.dims,
281272
attrs=self.meta.attributes,
