Skip to content

Commit 815671d

Browse files

Commit message: Fix shape errors in scalar_solve_to_division
1 parent: 80acf20 — this commit: 815671d

File tree

2 files changed

+62
-17
lines changed

2 files changed

+62
-17
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,11 +1046,15 @@ def scalar_solve_to_division(fgraph, node):
10461046
if not all(a.broadcastable[-2:]):
10471047
return None
10481048

1049+
if core_op.b_ndim == 1:
1050+
# Convert b to a column matrix
1051+
b = b[..., None]
1052+
10491053
# Special handling for different types of solve
10501054
match core_op:
10511055
case SolveTriangular():
10521056
# Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
1053-
new_out = b / a if not core_op.unit_diagonal else b
1057+
new_out = b / a if not core_op.unit_diagonal else pt.second(a, b)
10541058
case CholeskySolve():
10551059
new_out = b / a**2
10561060
case Solve():
@@ -1061,6 +1065,7 @@ def scalar_solve_to_division(fgraph, node):
10611065
)
10621066

10631067
if core_op.b_ndim == 1:
1068+
# Squeeze away the column dimension added earlier
10641069
new_out = new_out.squeeze(-1)
10651070

10661071
copy_stack_trace(old_out, new_out)

tests/tensor/rewriting/test_linalg.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor import tensor as pt
1111
from pytensor.compile import get_default_mode
1212
from pytensor.configdefaults import config
13+
from pytensor.graph import ancestors
1314
from pytensor.graph.rewriting.utils import rewrite_graph
1415
from pytensor.tensor import swapaxes
1516
from pytensor.tensor.blockwise import Blockwise
@@ -989,34 +990,73 @@ def test_slogdet_specialization():
989990

990991

991992
@pytest.mark.parametrize(
992-
"Op, fn",
993+
"a_batch_shape", [(), (5,)], ids=lambda x: f"a_batch_shape={x}"
994+
)
995+
@pytest.mark.parametrize(
996+
"b_batch_shape", [(), (5,)], ids=lambda x: f"b_batch_shape={x}"
997+
)
998+
@pytest.mark.parametrize("b_ndim", (1, 2), ids=lambda x: f"b_ndim={x}")
999+
@pytest.mark.parametrize(
1000+
"op, fn, extra_kwargs",
9931001
[
994-
(Solve, pt.linalg.solve),
995-
(SolveTriangular, pt.linalg.solve_triangular),
996-
(CholeskySolve, pt.linalg.cho_solve),
1002+
(Solve, pt.linalg.solve, {}),
1003+
(SolveTriangular, pt.linalg.solve_triangular, {}),
1004+
(SolveTriangular, pt.linalg.solve_triangular, {"unit_diagonal": True}),
1005+
(CholeskySolve, pt.linalg.cho_solve, {}),
9971006
],
9981007
)
999-
def test_scalar_solve_to_division_rewrite(Op, fn):
1000-
rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))
1008+
def test_scalar_solve_to_division_rewrite(
1009+
op, fn, extra_kwargs, b_ndim, a_batch_shape, b_batch_shape
1010+
):
1011+
def solve_op_in_graph(graph):
1012+
return any(
1013+
isinstance(var.owner.op, SolveBase)
1014+
or (
1015+
isinstance(var.owner.op, Blockwise)
1016+
and isinstance(var.owner.op.core_op, SolveBase)
1017+
)
1018+
for var in ancestors(graph)
1019+
if var.owner
1020+
)
1021+
1022+
rng = np.random.default_rng(
1023+
[
1024+
sum(map(ord, "scalar_solve_to_division_rewrite")),
1025+
b_ndim,
1026+
*a_batch_shape,
1027+
1,
1028+
*b_batch_shape,
1029+
]
1030+
)
10011031

1002-
a = pt.dmatrix("a", shape=(1, 1))
1003-
b = pt.dvector("b")
1032+
a = pt.tensor("a", shape=(*a_batch_shape, 1, 1), dtype="float64")
1033+
b = pt.tensor("b", shape=(*b_batch_shape, *([None] * b_ndim)), dtype="float64")
10041034

1005-
if Op is CholeskySolve:
1035+
if op is CholeskySolve:
10061036
# cho_solve expects a tuple (c, lower) as the first input
1007-
c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
1037+
c = fn((pt.linalg.cholesky(a), True), b, b_ndim=b_ndim, **extra_kwargs)
10081038
else:
1009-
c = fn(a, b, b_ndim=1)
1039+
c = fn(a, b, b_ndim=b_ndim, **extra_kwargs)
10101040

1041+
assert solve_op_in_graph([c])
10111042
f = function([a, b], c, mode="FAST_RUN")
1012-
nodes = f.maker.fgraph.apply_nodes
1043+
assert not solve_op_in_graph(f.maker.fgraph.outputs)
1044+
1045+
a_val = rng.normal(size=(*a_batch_shape, 1, 1)).astype(pytensor.config.floatX)
1046+
b_core_shape = (1, 5) if b_ndim == 2 else (1,)
1047+
b_val = rng.normal(size=(*b_batch_shape, *b_core_shape)).astype(
1048+
pytensor.config.floatX
1049+
)
10131050

1014-
assert not any(isinstance(node.op, Op) for node in nodes)
1051+
if op is CholeskySolve:
1052+
# Avoid sign ambiguity in solve
1053+
a_val = a_val**2
10151054

1016-
a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
1017-
b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)
1055+
if extra_kwargs.get("unit_diagonal", False):
1056+
a_val = np.ones_like(a_val)
10181057

1019-
c_val = np.linalg.solve(a_val, b_val)
1058+
signature = "(n,m),(m)->(n)" if b_ndim == 1 else "(n,m),(m,k)->(n,k)"
1059+
c_val = np.vectorize(np.linalg.solve, signature=signature)(a_val, b_val)
10201060
np.testing.assert_allclose(
10211061
f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
10221062
)

0 commit comments

Comments (0)