Skip to content

Commit e54f339

Browse files
Run mypy in pre-commit
1 parent 99cba96 commit e54f339

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,7 @@ repos:
3131
args: [--rcfile=.pylintrc]
3232
exclude: (test_*|mcbackend/meta.py|mcbackend/npproto/)
3333
files: ^mcbackend/
34+
- repo: https://github.com/pre-commit/mirrors-mypy
35+
rev: v0.991
36+
hooks:
37+
- id: mypy

mcbackend/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def test_chain_properties(self):
7272

7373
def test_chain_length(self):
7474
class _TestChain(core.Chain):
75-
def get_draws(self, var_name: str):
75+
def get_draws(self, var_name: str, slc: slice = slice(None)):
7676
return numpy.arange(12)
7777

78-
def get_stats(self, stat_name: str):
78+
def get_stats(self, stat_name: str, slc: slice = slice(None)):
7979
return numpy.arange(42)
8080

8181
rmeta = RunMeta("test", variables=[Variable("v1")])

mcbackend/test_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import random
22
import time
33
from dataclasses import dataclass
4-
from typing import Sequence
4+
from typing import Optional, Sequence
55

66
import arviz
77
import hagelkorn
@@ -78,9 +78,9 @@ def make_draw(variables: Sequence[Variable]):
7878
class BaseBackendTest:
7979
"""Can be used to test different backends in the same way."""
8080

81-
cls_backend = None
82-
cls_run = None
83-
cls_chain = None
81+
cls_backend: Optional[type] = None
82+
cls_run: Optional[type] = None
83+
cls_chain: Optional[type] = None
8484

8585
def setup_method(self, method):
8686
"""Override this when the backend has no parameterless constructor."""
@@ -373,10 +373,8 @@ def run_all_benchmarks(self) -> pandas.DataFrame:
373373
for attr in dir(BackendBenchmark):
374374
meth = getattr(self, attr, None)
375375
if callable(meth) and meth.__name__.startswith("measure_"):
376-
try:
376+
if hasattr(self, "setup_method"):
377377
self.setup_method(meth)
378-
except TypeError:
379-
pass
380378
print(f"Running {meth.__name__}")
381379
speed = meth()
382380
df.loc[meth.__name__[8:], ["bytes_per_draw", "append_speed", "description"]] = (

0 commit comments

Comments
 (0)