Skip to content

qml.sample() fails with OutDBIdx shape canonicalization error in dynamic one-shot context #1949

@rniczh

Description

@rniczh

Context

When using @qjit with finite shots, qml.sample() measurements fail with a shape canonicalization error involving OutDBIdx references, while qml.expval() measurements work correctly. The failure occurs in the dynamic_one_shot transformation, which is automatically applied when using finite shots, and it is triggered when wires is not set explicitly (as in the reproduction below, where the device is created without a wires argument).

Reproduction

❌ Failing Case (qml.sample())

import pennylane as qml
from catalyst import qjit

backend = "lightning.qubit"

@qjit
@qml.qnode(qml.device(backend, shots=10), mcm_method='one-shot')
def circuit():
    qml.RX(0.0, wires=3)
    return qml.sample()

circuit()

✅ Working Case (qml.expval())

@qjit
@qml.qnode(qml.device(backend, shots=10), mcm_method='one-shot')
def circuit():
    qml.RX(0.0, wires=3)
    return qml.expval(qml.PauliZ(0))  # This works

circuit()

Why expval() works

  • No OutDBIdx references are created: in the expval(...) case the shape is a ShapedType, which jnp.zeros() can handle.

Full Stack Trace

Traceback (most recent call last):
  File "/path/to/test.py", line 7, in <module>
    @qjit
     ^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 502, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 565, in __init__
    self.aot_compile()
    ~~~~~~~~~~~~~~~~^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 618, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
                                                              ~~~~~~~~~~~~^
        self.user_sig or ()
        ^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 759, in capture
    jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
                                        ~~~~~~~~~~~~~~^
        self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, dbg
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jax_tracer.py", line 613, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
                                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/catalyst/frontend/catalyst/jax_extras/tracing.py", line 499, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 749, in closure
    return QFunc.__call__(
           ~~~~~~~~~~~~~~^
        qnode,
        ^^^^^^
        *args,
        ^^^^^^
        **dict(params, **kwargs),
        ^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/qfunc.py", line 143, in __call__
    return Function(dynamic_one_shot(self, mcm_config=mcm_config))(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jax_tracer.py", line 181, in __call__
    jaxpr, _, out_tree = make_jaxpr2(
                         ~~~~~~~~~~~~
        self.fn,
        ~~~~~~~~
        debug_info=kwargs.pop("debug_info", jdb("Function", self.fn, args, kwargs)),
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(*args, **kwargs)
    ~^^^^^^^^^^^^^^^^^
  File "/path/to/catalyst/frontend/catalyst/jax_extras/tracing.py", line 499, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/qfunc.py", line 286, in one_shot_wrapper
    results = catalyst.vmap(wrap_single_shot_qnode)(arg_vmap)
  File "/path/to/catalyst/frontend/catalyst/api_extensions/function_maps.py", line 235, in __call__
    init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape, _ in shapes]
                        ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/numpy/array_creation.py", line 82, in zeros
    shape = canonicalize_shape(shape)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/numpy/array_creation.py", line 45, in canonicalize_shape
    return core.canonicalize_shape(shape, context)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/core.py", line 1864, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of integer scalars, got (1, OutDBIdx(val=0))

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions