diff --git a/CHANGELOG.md b/CHANGELOG.md index addf21992..868656025 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ ### Fixed - Raised an error when an expression is used when a variable is required - Fixed some compile warnings +- Fixed the type of @ matrix operation result from MatrixVariable to MatrixExpr. ### Changed - MatrixExpr.sum() now supports axis arguments and can return either a scalar or MatrixExpr, depending on the result dimensions. - AddMatrixCons() also accepts ExprCons. diff --git a/src/pyscipopt/matrix.pxi b/src/pyscipopt/matrix.pxi index 0548fbd10..9eb0349f2 100644 --- a/src/pyscipopt/matrix.pxi +++ b/src/pyscipopt/matrix.pxi @@ -98,7 +98,10 @@ class MatrixExpr(np.ndarray): def __rsub__(self, other): return super().__rsub__(other).view(MatrixExpr) - + + def __matmul__(self, other): + return super().__matmul__(other).view(MatrixExpr) + class MatrixGenExpr(MatrixExpr): pass diff --git a/tests/test_matrix_variable.py b/tests/test_matrix_variable.py index 0308bb694..c7db949b3 100644 --- a/tests/test_matrix_variable.py +++ b/tests/test_matrix_variable.py @@ -392,3 +392,20 @@ def test_matrix_cons_indicator(): assert m.getVal(is_equal).sum() == 2 assert (m.getVal(x) == m.getVal(y)).all().all() assert (m.getVal(x) == np.array([[5, 5, 5], [5, 5, 5]])).all().all() + + +def test_matrix_matmul_return_type(): + # test #1058, require returning type is MatrixExpr not MatrixVariable + m = Model() + + # test 1D @ 1D → 0D + x = m.addMatrixVar(3) + assert type(x @ x) is MatrixExpr + + # test 1D @ 1D → 2D + assert type(x[:, None] @ x[None, :]) is MatrixExpr + + # test 2D @ 2D → 2D + y = m.addMatrixVar((2, 3)) + z = m.addMatrixVar((3, 4)) + assert type(y @ z) is MatrixExpr