Skip to content

Commit c53f6dd

Browse files
committed
Fix NumPy 2.0 compatibility: replace np.find_common_type with np.result_type - Replace deprecated np.find_common_type calls in SumLinearOperator and ProductLinearOperator - Use functools.reduce with np.result_type to achieve same functionality - Fixes AttributeError when using NumPy 2.0+
1 parent 6f74249 commit c53f6dd

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

src/probnum/linops/_arithmetic_fallbacks.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ def __init__(self, *summands: LinearOperator):
9797

9898
super().__init__(
9999
shape=summands[0].shape,
100-
dtype=np.find_common_type(
101-
[summand.dtype for summand in self._summands], []
102-
),
100+
dtype=functools.reduce(np.result_type, [summand.dtype for summand in self._summands]),
103101
matmul=lambda x: functools.reduce(
104102
operator.add, (summand @ x for summand in self._summands)
105103
),
@@ -190,7 +188,7 @@ def __init__(self, *factors: LinearOperator):
190188

191189
super().__init__(
192190
shape=(self._factors[0].shape[0], self._factors[-1].shape[1]),
193-
dtype=np.find_common_type([factor.dtype for factor in self._factors], []),
191+
dtype=functools.reduce(np.result_type, [factor.dtype for factor in self._factors]),
194192
matmul=lambda x: functools.reduce(
195193
lambda vec, op: op @ vec, reversed(self._factors), x
196194
),

0 commit comments

Comments
 (0)