-
Notifications
You must be signed in to change notification settings - Fork 53
Open
Description
Context
When using `@qjit` with finite shots, `qml.sample()` measurements fail with a shape-canonicalization error involving `OutDBIdx` references, while `qml.expval()` measurements work correctly.
This happens in the `dynamic_one_shot` transformation, which is applied automatically when finite shots are used, in the case where `wires` is not specified for the device passed to `qml.qnode`.
Reproduction
❌ Failing Case (qml.sample())
# Failing case: returning qml.sample() from a qjit-compiled QNode that runs
# with finite shots and the one-shot mid-circuit-measurement method.
# Per the traceback below, the TypeError is raised while @qjit compiles the
# function ahead of time (QJIT.__init__ -> aot_compile), i.e. before the
# final `circuit()` call is reached.
import pennylane as qml
from catalyst import qjit

backend = "lightning.qubit"


@qjit
@qml.qnode(qml.device(backend, shots=10), mcm_method='one-shot')
def circuit():
    qml.RX(0.0, wires=3)
    return qml.sample()


circuit()
✅ Working Case (qml.expval())
# Working case: same device, shots and mcm_method, but returning an
# expectation value instead of samples. Reuses `qml`, `qjit` and `backend`
# from the failing example above.
@qjit
@qml.qnode(qml.device(backend, shots=10), mcm_method='one-shot')
def circuit():
    qml.RX(0.0, wires=3)
    return qml.expval(qml.PauliZ(0))  # This works


circuit()
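Why expval() works
- No `OutDBIdx` references are created: the shape for the `expval(...)` case is a `ShapedType`, which `jnp.zeros()` can handle.
- For `qml.sample()`, by contrast, the result shape passed to `jnp.zeros()` in `catalyst.vmap` contains an `OutDBIdx` entry (for example `(1, OutDBIdx(val=0))`). JAX's `canonicalize_shape` rejects it because it is not an integer scalar.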
Full Stack Trace
Traceback (most recent call last):
File "/path/to/test.py", line 7, in <module>
@qjit
^^^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/jit.py", line 502, in qjit
return QJIT(fn, CompileOptions(**kwargs))
File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
output = func(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/jit.py", line 565, in __init__
self.aot_compile()
~~~~~~~~~~~~~~~~^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/jit.py", line 618, in aot_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
~~~~~~~~~~~~^
self.user_sig or ()
^^^^^^^^^^^^^^^^^^^
)
^
File "/path/to/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/jit.py", line 759, in capture
jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
~~~~~~~~~~~~~~^
self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, dbg
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/jax_tracer.py", line 613, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/path/to/catalyst/frontend/catalyst/jax_extras/tracing.py", line 499, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
~~~~~~~~~~~~~~~~~~~~~~~^^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
return func(*args, **kwargs)
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers)
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
return self.f_transformed(*args, **kwargs)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
ans = f(*py_args, **py_kwargs)
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
ans = _fun(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/jit.py", line 749, in closure
return QFunc.__call__(
~~~~~~~~~~~~~~^
qnode,
^^^^^^
*args,
^^^^^^
**dict(params, **kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/qfunc.py", line 143, in __call__
return Function(dynamic_one_shot(self, mcm_config=mcm_config))(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/jax_tracer.py", line 181, in __call__
jaxpr, _, out_tree = make_jaxpr2(
~~~~~~~~~~~~
self.fn,
~~~~~~~~
debug_info=kwargs.pop("debug_info", jdb("Function", self.fn, args, kwargs)),
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(*args, **kwargs)
~^^^^^^^^^^^^^^^^^
File "/path/to/catalyst/frontend/catalyst/jax_extras/tracing.py", line 499, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
~~~~~~~~~~~~~~~~~~~~~~~^^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
return func(*args, **kwargs)
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers)
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
return self.f_transformed(*args, **kwargs)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
ans = f(*py_args, **py_kwargs)
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
ans = _fun(*args, **kwargs)
File "/path/to/catalyst/frontend/catalyst/qfunc.py", line 286, in one_shot_wrapper
results = catalyst.vmap(wrap_single_shot_qnode)(arg_vmap)
File "/path/to/catalyst/frontend/catalyst/api_extensions/function_maps.py", line 235, in __call__
init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape, _ in shapes]
~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/numpy/array_creation.py", line 82, in zeros
shape = canonicalize_shape(shape)
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/numpy/array_creation.py", line 45, in canonicalize_shape
return core.canonicalize_shape(shape, context)
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/core.py", line 1864, in canonicalize_shape
raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of integer scalars, got (1, OutDBIdx(val=0))
Metadata
Metadata
Assignees
Labels
No labels