Skip to content

Commit 30bb492

Browse files
authored
Allow type checking on Python 3.11 (#1181)
2 parents e1e475a + dc5e117 commit 30bb492

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

ethicml/implementations/beutel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def forward(ctx: Any, x: Tensor, lambda_: float) -> Any:
366366
return x.view_as(x)
367367

368368
@staticmethod
369-
def backward(ctx: Any, grad_output: Tensor) -> Any: # type: ignore[override]
369+
def backward(ctx: Any, grad_output: Tensor) -> Any: # pyright: ignore
370370
"""Backward pass with Gradient reversed / inverted."""
371371
return grad_output.neg().mul(ctx.lambda_), None
372372

ethicml/metrics/dependence_measures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,4 @@ def _corr(x: np.ndarray, y: np.ndarray) -> float:
156156

157157
def _count_true(mask: np.ndarray) -> int:
158158
"""Count the number of elements that are True."""
159-
return np.count_nonzero(mask)
159+
return np.count_nonzero(mask).item()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ typecheck = [
7575
"pandas-stubs>=1.4.2.220626",
7676
"scipy-stubs>=1.15.3.0",
7777
"types-seaborn<1.0.0.0,>=0.13.2.20240205",
78+
"backports.strenum>=1.3.1",
7879
]
7980
lint = [
8081
"ruff>=0.2.2",

uv.lock

Lines changed: 5 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)