Skip to content

Commit b38481a

Browse files
committed
Add tests for matrix matmul return types
Adds tests to verify that matrix multiplication returns MatrixExpr instead of MatrixVariable for various input shapes.
1 parent 6139a43 commit b38481a

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/test_matrix_variable.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,20 @@ def test_matrix_cons_indicator():
392392
assert m.getVal(is_equal).sum() == 2
393393
assert (m.getVal(x) == m.getVal(y)).all().all()
394394
assert (m.getVal(x) == np.array([[5, 5, 5], [5, 5, 5]])).all().all()
395+
396+
397+
def test_matrix_matmul_return_type():
398+
# test #1058, require returning type is MatrixExpr not MatrixVariable
399+
m = Model()
400+
401+
# test 1D @ 1D → 0D
402+
x = m.addMatrixVar(3)
403+
assert isinstance(x @ x, MatrixExpr)
404+
405+
# test 1D @ 1D → 2D
406+
assert isinstance(x[:, None] @ x[None, :], MatrixExpr)
407+
408+
# test 2D @ 2D → 2D
409+
y = m.addMatrixVar((3, 4))
410+
z = m.addMatrixVar((2, 3))
411+
assert isinstance(y @ z, MatrixExpr)

0 commit comments

Comments
 (0)