|
10 | 10 | from pytensor import tensor as pt
|
11 | 11 | from pytensor.compile import get_default_mode
|
12 | 12 | from pytensor.configdefaults import config
|
| 13 | +from pytensor.graph import ancestors |
13 | 14 | from pytensor.graph.rewriting.utils import rewrite_graph
|
14 | 15 | from pytensor.tensor import swapaxes
|
15 | 16 | from pytensor.tensor.blockwise import Blockwise
|
@pytest.mark.parametrize(
    "a_batch_shape", [(), (5,)], ids=lambda x: f"a_batch_shape={x}"
)
@pytest.mark.parametrize(
    "b_batch_shape", [(), (5,)], ids=lambda x: f"b_batch_shape={x}"
)
@pytest.mark.parametrize("b_ndim", (1, 2), ids=lambda x: f"b_ndim={x}")
@pytest.mark.parametrize(
    "op, fn, extra_kwargs",
    [
        (Solve, pt.linalg.solve, {}),
        (SolveTriangular, pt.linalg.solve_triangular, {}),
        (SolveTriangular, pt.linalg.solve_triangular, {"unit_diagonal": True}),
        (CholeskySolve, pt.linalg.cho_solve, {}),
    ],
)
def test_scalar_solve_to_division_rewrite(
    op, fn, extra_kwargs, b_ndim, a_batch_shape, b_batch_shape
):
    """Solves against a 1x1 `a` should be rewritten away from Solve ops.

    Builds ``solve(a, b)`` (for every Solve variant, batched and unbatched,
    with 1-D and 2-D ``b``), compiles with rewrites enabled, asserts that no
    Solve op survives in the compiled graph, and checks the numerical result
    against a NumPy reference.
    """

    def solve_op_in_graph(graph):
        # True if any ancestor of `graph` is a SolveBase op, either applied
        # directly or wrapped in a Blockwise (the batched case).
        return any(
            isinstance(var.owner.op, SolveBase)
            or (
                isinstance(var.owner.op, Blockwise)
                and isinstance(var.owner.op.core_op, SolveBase)
            )
            for var in ancestors(graph)
            if var.owner
        )

    # Fold every parametrization value into the seed sequence so each case
    # draws different, but reproducible, data. The literal 1 separates the
    # two (possibly empty) batch-shape tuples so distinct parametrizations
    # cannot collapse to the same seed.
    rng = np.random.default_rng(
        [
            sum(map(ord, "scalar_solve_to_division_rewrite")),
            b_ndim,
            *a_batch_shape,
            1,
            *b_batch_shape,
        ]
    )

    # Core shape of `a` is (1, 1): scalar systems are what the rewrite targets.
    a = pt.tensor("a", shape=(*a_batch_shape, 1, 1), dtype="float64")
    b = pt.tensor("b", shape=(*b_batch_shape, *([None] * b_ndim)), dtype="float64")

    if op is CholeskySolve:
        # cho_solve expects a tuple (c, lower) as the first input
        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=b_ndim, **extra_kwargs)
    else:
        c = fn(a, b, b_ndim=b_ndim, **extra_kwargs)

    # Sanity check: the unrewritten graph does contain a Solve op...
    assert solve_op_in_graph([c])
    f = function([a, b], c, mode="FAST_RUN")
    # ...and after compiling with rewrites enabled it has been eliminated.
    assert not solve_op_in_graph(f.maker.fgraph.outputs)

    a_val = rng.normal(size=(*a_batch_shape, 1, 1)).astype(pytensor.config.floatX)
    # Core shape of b: a vector of length 1, or a (1, 5) matrix of right-hand sides.
    b_core_shape = (1, 5) if b_ndim == 2 else (1,)
    b_val = rng.normal(size=(*b_batch_shape, *b_core_shape)).astype(
        pytensor.config.floatX
    )

    if op is CholeskySolve:
        # Avoid sign ambiguity in solve: squaring makes each 1x1 `a` positive,
        # so the Cholesky factorization is well defined.
        a_val = a_val**2

    if extra_kwargs.get("unit_diagonal", False):
        # With unit_diagonal=True solve_triangular treats the diagonal as 1
        # and ignores the stored values; use ones so the plain NumPy solve
        # below computes the same answer.
        a_val = np.ones_like(a_val)

    # NumPy reference, vectorized over the batch dimensions of both operands.
    signature = "(n,m),(m)->(n)" if b_ndim == 1 else "(n,m),(m,k)->(n,k)"
    c_val = np.vectorize(np.linalg.solve, signature=signature)(a_val, b_val)
    np.testing.assert_allclose(
        f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
    )
|
0 commit comments