From 4a8c73a8fb6e0e72ed5885e707066c072d07f3ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Wed, 13 Aug 2025 17:30:15 +0100 Subject: [PATCH 01/11] feat: Add a LoadedParam struct for the pytket decoder --- tket/src/serialize/pytket/decoder/param.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tket/src/serialize/pytket/decoder/param.rs b/tket/src/serialize/pytket/decoder/param.rs index cae1eed45..b9d5cda5a 100644 --- a/tket/src/serialize/pytket/decoder/param.rs +++ b/tket/src/serialize/pytket/decoder/param.rs @@ -48,10 +48,6 @@ impl LoadedParameter { } /// Returns the hugr type for the parameter. - #[expect( - dead_code, - reason = "Temporarily unused while we refactor the pytket decoder" - )] pub fn wire_type(&self) -> &Type { static FLOAT_TYPE: LazyLock = LazyLock::new(float64_type); static ROTATION_TYPE: LazyLock = LazyLock::new(rotation_type); From 4066812a4cfc12748922071becdcdc28a2649dc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Wed, 13 Aug 2025 18:00:21 +0100 Subject: [PATCH 02/11] Delete the previous LoadedParameter --- tket/src/serialize/pytket/decoder/param.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tket/src/serialize/pytket/decoder/param.rs b/tket/src/serialize/pytket/decoder/param.rs index b9d5cda5a..cae1eed45 100644 --- a/tket/src/serialize/pytket/decoder/param.rs +++ b/tket/src/serialize/pytket/decoder/param.rs @@ -48,6 +48,10 @@ impl LoadedParameter { } /// Returns the hugr type for the parameter. + #[expect( + dead_code, + reason = "Temporarily unused while we refactor the pytket decoder" + )] pub fn wire_type(&self) -> &Type { static FLOAT_TYPE: LazyLock = LazyLock::new(float64_type); static ROTATION_TYPE: LazyLock = LazyLock::new(rotation_type); From 5e26e5c233d2dddc4189a4c921a27382a0470989 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Wed, 13 Aug 2025 17:26:50 +0100 Subject: [PATCH 03/11] feat: Define a wire tracker for the new pytket decoder --- tket-py/src/circuit.rs | 6 + tket/src/serialize/pytket.rs | 168 ++++ .../serialize/pytket/decoder/tracked_elem.rs | 151 +++ tket/src/serialize/pytket/decoder/wires.rs | 861 ++++++++++++++++++ 4 files changed, 1186 insertions(+) create mode 100644 tket/src/serialize/pytket/decoder/tracked_elem.rs create mode 100644 tket/src/serialize/pytket/decoder/wires.rs diff --git a/tket-py/src/circuit.rs b/tket-py/src/circuit.rs index 15153ff1a..9e2aa73d6 100644 --- a/tket-py/src/circuit.rs +++ b/tket-py/src/circuit.rs @@ -81,6 +81,12 @@ create_py_exception!( "Error type for the conversion between tket and tket1 operations." ); +create_py_exception!( + tket::serialize::pytket::Tk1DecodeError, + PyTK1DecodeError, + "Error type for the conversion between tket1 and tket operations." +); + /// Run the validation checks on a circuit. #[pyfunction] pub fn validate_circuit(c: &Bound) -> PyResult<()> { diff --git a/tket/src/serialize/pytket.rs b/tket/src/serialize/pytket.rs index c994d4093..8dbbb347b 100644 --- a/tket/src/serialize/pytket.rs +++ b/tket/src/serialize/pytket.rs @@ -279,6 +279,174 @@ impl Tk1ConvertError { } } +/// Error type for conversion between tket2 ops and pytket operations. +#[derive(derive_more::Debug, Display, Error)] +#[non_exhaustive] +pub enum Tk1DecodeError { + /// The pytket circuit uses multi-indexed registers. + // + // This could be supported in the future, if there is a need for it. + #[display("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers.")] + MultiIndexedRegister { + /// The register name. + register: String, + }, + /// Found an unexpected register name. + #[display("Found an unknown qubit register name: {register}.")] + UnknownQubitRegister { + /// The unknown register name. + register: String, + }, + /// Found an unexpected bit register name. + #[display("Found an unknown bit register name: {register}.")] + UnknownBitRegister { + /// The unknown register name. + register: String, + }, + /// The given signature to use for the HUGR's input wires is not compatible with the number of qubits and bits in the pytket circuit. + /// + /// The expected number of qubits and bits may be different depending on the [`PytketTypeTranslator`][extension::PytketTypeTranslator]s used in the decoder config. + #[display( + "The given input types {input_types} to use for the HUGR's input wires are not compatible with the number of qubits and bits in the pytket circuit. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits.", + input_types = input_types.iter().join(", "), + )] + InvalidInputSignature { + /// The given input types. + input_types: Vec, + /// The expected number of qubits in the signature. + expected_qubits: usize, + /// The expected number of bits in the signature. + expected_bits: usize, + /// The number of qubits in the pytket circuit. + circ_qubits: usize, + /// The number of bits in the pytket circuit. + circ_bits: usize, + }, + /// The signature to use for the HUGR's output wires is not compatible with the number of qubits and bits in the pytket circuit. + /// + /// We don't do any kind of type conversion, so this depends solely on the last operation to update each register. + #[display( + "The expected output types {expected_types} are not compatible with the actual output types {actual_types}, obtained from decoding the pytket circuit.", + expected_types = expected_types.iter().join(", "), + actual_types = actual_types.iter().join(", "), + )] + InvalidOutputSignature { + /// The expected types of the input wires. + expected_types: Vec, + /// The actual types of the input wires. + actual_types: Vec, + }, + /// A pytket operation had some input registers that couldn't be mapped to hugr wires. + // + // Some of this errors will be avoided in the future once we are able to decompose complex types automatically. + #[display( + "Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}] for operation {operation}.", + qubit_args = qubit_args.iter().join(", "), + bit_args = bit_args.iter().join(", "), + )] + ArgumentCouldNotBeMapped { + /// The operation type that was being decoded. + operation: String, + /// The qubit arguments that couldn't be mapped. + qubit_args: Vec, + /// The bit arguments that couldn't be mapped. + bit_args: Vec, + }, + /// Found an unexpected number of input wires when decoding an operation. + #[display( + "Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters when decoding a {operation}, but found {actual_values} values{actual_types} and {actual_params} parameters.", + expected_types = match expected_types { + None => "".to_string(), + Some(tys) => format!(" with types [{}]", tys.iter().join(", ")), + }, + actual_types = match actual_types { + None => "".to_string(), + Some(tys) => format!(" with types [{}]", tys.iter().join(", ")), + }, + )] + UnexpectedInputWires { + /// The expected amount of input wires. + expected_values: usize, + /// The expected amount of input parameters. + expected_params: usize, + /// The actual amount of input wires. + actual_values: usize, + /// The actual amount of input parameters. + actual_params: usize, + /// The expected types of the input wires. + expected_types: Option>, + /// The actual types of the input wires. + actual_types: Option>, + /// The operation type that was being decoded. + operation: String, + }, + /// Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. + #[display( + "Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits in the node outputs." + )] + UnexpectedNodeOutput { + /// The expected number of qubits. + expected_qubits: usize, + /// The expected number of bits. + expected_bits: usize, + /// The number of qubits in HUGR node outputs. + circ_qubits: usize, + /// The number of bits in HUGR node output. + circ_bits: usize, + }, + /// Custom user-defined error raised while encoding an operation. + #[display("Error while decoding operation: {msg}")] + CustomError { + /// The custom error message + msg: String, + }, + /// Input parameter was defined multiple times. + #[display("Parameter {param} was defined multiple times in the input signature")] + DuplicatedParameter { + /// The parameter name. + param: String, + }, + /// Not enough parameter names given for the input signature. + #[display("Tried to initialize a pytket circuit decoder with {num_params_given} given parameter names, but more were required by the input signature.")] + MissingParametersInInput { + /// The number of parameters given. + num_params_given: usize, + }, + /// We don't support complex types containing parameters in the input. + // + // This restriction may be relaxed in the future. + #[display("Complex type {ty} contains {num_params} inside it. We only support input parameters in standalone 'float' or 'rotation'-typed wires.")] + UnsupportedParametersInInput { + /// The type that contains the parameters. + ty: String, + /// The number of parameters in the type. + num_params: usize, + }, +} + +impl Tk1DecodeError { + /// Create a new error with a custom message. + pub fn custom(msg: impl ToString) -> Self { + Self::CustomError { + msg: msg.to_string(), + } + } + + /// Create an error for an unknown qubit register. + pub fn unknown_qubit_reg(register: &tket_json_rs::register::ElementId) -> Self { + Self::UnknownQubitRegister { + register: register.to_string(), + } + } + + /// Create an error for an unknown bit register. + pub fn unknown_bit_reg(register: &tket_json_rs::register::ElementId) -> Self { + Self::UnknownBitRegister { + register: register.to_string(), + } + } +} + /// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map, /// avoiding string and vector clones on lookup. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] diff --git a/tket/src/serialize/pytket/decoder/tracked_elem.rs b/tket/src/serialize/pytket/decoder/tracked_elem.rs new file mode 100644 index 000000000..4b914da1b --- /dev/null +++ b/tket/src/serialize/pytket/decoder/tracked_elem.rs @@ -0,0 +1,151 @@ +//! Pytket qubit and bit elements that we track during decoding. + +use std::sync::{Arc, LazyLock}; + +use hugr::extension::prelude::{bool_t, qb_t}; +use hugr::types::Type; +use tket_json_rs::register::ElementId as PytketRegister; + +/// An internal lightweight identifier for a [`TrackedQubit`] in the decoder. +#[derive(Clone, Copy, Debug, derive_more::Display, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub(super) struct TrackedQubitId(#[display(transparent)] pub usize); + +/// An internal lightweight identifier for a [`TrackedBit`] in the decoder. +#[derive(Clone, Copy, Debug, derive_more::Display, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub(super) struct TrackedBitId(#[display(transparent)] pub usize); + +/// An identifier for a pytket qubit register in the data carried by a wire. +/// +/// After a pytket circuit assigns a new value to the register, older +/// [`TrackedQubit`]s referring to it become _outdated_. +/// +/// Outdated values no longer correspond to a pytket circuit register, but they +/// can still be found in the wires of the hugr being extracted. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TrackedQubit { + outdated: bool, + reg: Arc, +} + +/// An identifier for a pytket bit register in the data carried by a wire. +/// +/// After a pytket circuit assigns a new value to the register, older +/// [`TrackedBit`]s referring to it become _outdated_. +/// +/// Outdated values no longer correspond to a pytket circuit register, but they +/// can still be found in the wires of the hugr being extracted. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct TrackedBit { + outdated: bool, + reg: Arc, +} + +impl TrackedQubit { + /// Return a new tracked qubit. + pub(super) fn new(reg: Arc) -> Self { + Self { + outdated: false, + reg, + } + } + + /// Returns the pytket register for this tracked element. + pub fn pytket_register(&self) -> &PytketRegister { + &self.reg + } + + /// Returns the pytket register for this tracked element. + pub fn pytket_register_arc(&self) -> Arc { + self.reg.clone() + } + + /// Returns the type of the element. + pub fn ty(&self) -> Arc { + static QUBIT_TYPE: LazyLock> = LazyLock::new(|| qb_t().into()); + QUBIT_TYPE.clone() + } + + /// Returns `true` if the element has been overwritten by a new value. + pub fn is_outdated(&self) -> bool { + self.outdated + } + + /// Mark the element as outdated. + pub(super) fn mark_outdated(&mut self) { + self.outdated = true; + } +} + +impl TrackedBit { + /// Returns a new tracked bit. + pub(super) fn new(reg: Arc) -> Self { + Self { + outdated: false, + reg, + } + } + + /// Returns the pytket register for this tracked element. + pub fn pytket_register(&self) -> &PytketRegister { + &self.reg + } + + /// Returns the pytket register for this tracked element. + pub fn pytket_register_arc(&self) -> Arc { + self.reg.clone() + } + + /// Returns the type of the element. + pub fn ty(&self) -> Arc { + static BOOL_TYPE: LazyLock> = LazyLock::new(|| bool_t().into()); + BOOL_TYPE.clone() + } + + /// Returns `true` if the element has been overwritten by a new value. + pub fn is_outdated(&self) -> bool { + self.outdated + } + + /// Mark the element as outdated. + pub(super) fn mark_outdated(&mut self) { + self.outdated = true; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use hugr::extension::prelude::{bool_t, qb_t}; + use hugr::types::Type; + use rstest::rstest; + use std::sync::Arc; + use tket_json_rs::register::ElementId; + + #[rstest] + fn tracked_qubit_basic_behaviour() { + let reg = Arc::new(ElementId("q".to_string(), vec![0])); + let mut tq = TrackedQubit::new(reg.clone()); + + assert!(!tq.is_outdated()); + assert_eq!(tq.pytket_register(), &*reg); + assert_eq!(tq.pytket_register_arc(), reg); + assert_eq!(&*tq.ty(), &Type::from(qb_t())); + + tq.mark_outdated(); + assert!(tq.is_outdated()); + } + + #[rstest] + fn tracked_bit_basic_behaviour() { + let reg = Arc::new(ElementId("c".to_string(), vec![1])); + let mut tb = TrackedBit::new(reg.clone()); + + assert!(!tb.is_outdated()); + assert_eq!(tb.pytket_register(), &*reg); + assert_eq!(tb.pytket_register_arc(), reg); + assert_eq!(&*tb.ty(), &Type::from(bool_t())); + + tb.mark_outdated(); + assert!(tb.is_outdated()); + } +} diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs new file mode 100644 index 000000000..ecebcb51d --- /dev/null +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -0,0 +1,861 @@ +//! Structures to keep track of pytket [`ElementId`][tket_json_rs::register::ElementId]s and +//! their correspondence to wires in the hugr being defined. + +use std::collections::VecDeque; +use std::sync::Arc; + +use hugr::builder::{Dataflow as _, FunctionBuilder}; +use hugr::ops::Value; +use hugr::std_extensions::arithmetic::float_types::ConstF64; +use hugr::types::Type; +use hugr::{Hugr, Wire}; +use indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; +use tket_json_rs::register::ElementId as PytketRegister; + +use crate::extension::rotation::rotation_type; +use crate::serialize::pytket::decoder::param::parser::{parse_pytket_param, PytketParam}; +use crate::serialize::pytket::decoder::{ + LoadedParameter, TrackedBit, TrackedBitId, TrackedQubit, TrackedQubitId, +}; +use crate::serialize::pytket::extension::RegisterCount; +use crate::serialize::pytket::{RegisterHash, Tk1DecodeError}; +use crate::symbolic_constant_op; + +/// Tracked data for a wire in [`TrackedWires`]. +#[derive(Debug, Clone, PartialEq)] +pub struct WireData { + /// The identifier in the hugr. + wire: Wire, + /// The type of the wire. + ty: Arc, + /// List of pytket qubit arguments corresponding to this wire. + qubits: Vec, + /// List of pytket bit arguments corresponding to this wire. + bits: Vec, +} + +impl WireData { + /// The wire identifier. + pub fn wire(&self) -> Wire { + self.wire + } + + /// The HUGR type for the wire. + pub fn ty(&self) -> &Type { + &self.ty + } + + /// The HUGR type for the wire. + pub fn ty_arc(&self) -> Arc { + self.ty.clone() + } + + /// The pytket qubit arguments corresponding to this wire. + pub fn qubits<'d>( + &'d self, + wire_tracker: &'d WireTracker, + ) -> impl Iterator + 'd { + self.qubits + .iter() + .map(move |elem_id| wire_tracker.get_qubit(*elem_id)) + .cloned() + } + + /// The pytket bit arguments corresponding to this wire. + pub fn bits<'d>( + &'d self, + wire_tracker: &'d WireTracker, + ) -> impl Iterator + 'd { + self.bits + .iter() + .map(move |elem_id| wire_tracker.get_bit(*elem_id)) + .cloned() + } + + /// Returns the number of qubits carried by this wire. + pub fn num_qubits(&self) -> usize { + self.qubits.len() + } + + /// Returns the number of bits carried by this wire. + pub fn num_bits(&self) -> usize { + self.bits.len() + } + + /// Returns the number of tracked elements in this wire. + pub fn num_args(&self) -> usize { + self.num_qubits() + self.num_bits() + } +} + +/// Tracked wires to a pytket operation. +#[derive(Debug, Clone)] +pub struct TrackedWires { + /// Computed list of wires corresponding to the arguments, + /// along with their types. + value_wires: Vec, + /// List of wires corresponding to the parameters. + parameter_wires: Vec>, +} + +impl TrackedWires { + /// Retrieve the wire data at the given index. + /// + /// Panics if the index is out of bounds. See [`TrackedWires::len`]. + #[inline] + #[must_use] + pub fn value_wire(&self, idx: usize) -> &WireData { + self.value_wires.get(idx).unwrap_or_else(|| { + panic!( + "Cannot get wire data at index {idx}, only {} wires are tracked", + self.value_wires.len() + ) + }) + } + + /// Return the number of value wires tracked. + #[inline] + #[must_use] + pub fn value_count(&self) -> usize { + self.value_wires.len() + } + + /// Return the number of parameter wires tracked. + #[inline] + #[must_use] + pub fn parameter_count(&self) -> usize { + self.parameter_wires.len() + } + + /// Return the number of wires tracked. + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.value_wires.len() + self.parameter_wires.len() + } + + /// Return whether there are no tracked wires. + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.value_wires.is_empty() && self.parameter_wires.is_empty() + } + + /// Return an iterator over the wires and their types. + /// + /// This returns the wires as-is, without any additional conversions. + /// If you need to retrieve a specific wire type, use TODO + #[inline] + pub fn iter_values(&self) -> impl Iterator + Clone + '_ { + self.value_wires.iter() + } + + /// Return an iterator over the parameters. + #[inline] + pub fn iter_parameters(&self) -> impl Iterator + Clone + '_ { + self.parameter_wires.iter().map(|p| p.as_ref()) + } + + /// Returns the types of the value wires. + #[inline] + pub fn value_types(&self) -> impl Iterator + Clone + '_ { + self.value_wires.iter().map(|wd| wd.ty()) + } + + /// Returns the types of the parameter wires. + #[inline] + pub fn parameter_types(&self) -> impl Iterator + Clone + '_ { + self.parameter_wires.iter().map(|p| p.wire_type()) + } + + /// Returns the wire types in this tracked wires. + #[inline] + pub fn wire_types(&self) -> impl Iterator + Clone + '_ { + self.value_types().chain(self.parameter_types()) + } + + /// Returns the tracked qubit elements in the set of wires. + #[inline] + pub fn qubits<'d>( + &'d self, + wire_tracker: &'d WireTracker, + ) -> impl Iterator + 'd { + self.value_wires + .iter() + .flat_map(move |wd| wd.qubits(wire_tracker)) + } + + /// Returns the tracked qubit elements in the set of wires as an array. + #[inline] + pub fn qubits_arr( + &self, + wire_tracker: &WireTracker, + ) -> Option<[TrackedQubit; N]> { + self.qubits(wire_tracker).collect_array() + } + + /// Returns the tracked bit elements in the set of wires. + #[inline] + pub fn bits<'d>( + &'d self, + wire_tracker: &'d WireTracker, + ) -> impl Iterator + 'd { + self.value_wires + .iter() + .flat_map(move |wd| wd.bits(wire_tracker)) + } + + /// Return the tracked value wires in this tracked wires. + #[inline] + pub fn value_wires(&self) -> impl Iterator + Clone + '_ { + self.value_wires.iter().map(|wd| wd.wire()) + } + + /// Return the tracked parameter wires in this tracked wires. + #[inline] + pub fn parameter_wires(&self) -> impl Iterator + Clone + '_ { + self.parameter_wires.iter().map(|p| p.wire) + } + + /// Returns the wires in this tracked wires. + #[inline] + pub fn wires(&self) -> impl Iterator + Clone + '_ { + self.value_wires().chain(self.parameter_wires()) + } + + /// Returns the wires in this tracked wires as an array of types. + /// + /// Returns an error if the number of wires is not equal to `N`. + /// + /// # Arguments + /// + /// * `operation` - The name of the operation being decoded, used for error reporting. + #[inline] + pub fn wires_arr(&self, operation: &str) -> Result<[Wire; N], Tk1DecodeError> { + let expected_values = N.saturating_sub(self.parameter_count()); + let expected_params = N - expected_values; + self.check_len(expected_values, expected_params, operation)?; + Ok(self + .wires() + .collect_array() + .expect("check_len should have failed")) + } + + /// Returns the amount of qubits, bits, and parameters carried by this tracked wires. + #[inline] + #[must_use] + pub fn register_count(&self) -> RegisterCount { + let mut counts: RegisterCount = self + .iter_values() + .map(|w| RegisterCount::new(w.num_qubits(), w.num_bits(), 0)) + .sum(); + counts.params += self.parameter_count(); + counts + } + + /// Checks that we have the expected number of wires, and returns an error otherwise. + /// + /// # Arguments + /// + /// * `expected_values` - The expected number of value wires. + /// * `expected_params` - The expected number of parameter wires. + /// * `operation` - The name of the operation being decoded, used for error reporting. + pub fn check_len( + &self, + expected_values: usize, + expected_params: usize, + operation: &str, + ) -> Result<(), Tk1DecodeError> { + if self.value_count() != expected_values || self.parameter_count() != expected_params { + let types = self.wire_types().map(|ty| ty.to_string()).collect_vec(); + Err(Tk1DecodeError::UnexpectedInputWires { + expected_values, + expected_params, + actual_values: self.value_count(), + actual_params: self.parameter_count(), + expected_types: None, + actual_types: Some(types), + operation: operation.to_string(), + }) + } else { + Ok(()) + } + } + + /// Checks that we have the expected wire types, and returns an error otherwise. + /// + /// # Arguments + /// + /// * `expected_values` - The expected types of the value wires. + /// * `expected_params` - The expected number of parameters. Note that these may be either `float` or `rotation`-typed. + /// Use [`LoadedParameter::with_type`] to cast them as needed. + /// * `operation` - The name of the operation being decoded, used for error reporting. + pub fn check_types( + &self, + expected_values: &[Type], + expected_params: usize, + operation: &str, + ) -> Result<(), Tk1DecodeError> { + let vals = expected_values.iter(); + if !itertools::equal(self.value_types(), vals) || self.parameter_count() != expected_params + { + let actual = self.value_types().collect_vec(); + Err(Tk1DecodeError::UnexpectedInputWires { + expected_values: expected_values.len(), + expected_params, + actual_values: self.value_count(), + actual_params: self.parameter_count(), + expected_types: Some( + expected_values + .iter() + .map(|ty| ty.to_string()) + .collect_vec(), + ), + actual_types: Some(actual.iter().map(|ty| ty.to_string()).collect_vec()), + operation: operation.to_string(), + }) + } else { + Ok(()) + } + } +} + +/// Tracker for wires added to a hugr. +/// +/// Keeps track of the wires added to the hugr, and the qubit/bit/parameters +/// that they contain. +/// +/// Wire may contain either a single [`LoadedParameter`] or a collection of +/// [`TrackedQubit`]s and [`TrackedBit`]s. Each tracked +/// element in a wire is said to be "up to date" if it is the latest reference +/// to that pytket register. Once the register is seen in the output of an +/// operation, all previous references to it become "outdated". +#[derive(Debug, Clone, Default)] +pub struct WireTracker { + /// A list of tracked wires, with their type and list of + /// tracked pytket elements and arguments. + wires: IndexMap, + /// The list of tracked qubit elements. + /// + /// Indexed by [`TrackedQubitId`]. + qubits: Vec, + /// The list of tracked bit elements. + /// + /// Indexed by [`TrackedBitId`]. + bits: Vec, + /// A map from pytket register hashes to the latest up-to-date [`TrackedQubit`] referencing it. + /// + /// The map keys are kept in the order they were defined in the circuit. + latest_qubit_tracker: IndexMap, + /// A map from pytket register hashes to the latest up-to-date [`TrackedBit`] referencing it. + /// + /// The map keys are kept in the order they were defined in the circuit. + latest_bit_tracker: IndexMap, + /// For each tracked qubit, the list of wires that contain it. + qubit_wires: IndexMap>, + /// For each tracked bit, the list of wires that contain it. + bit_wires: IndexMap>, + /// An ordered set of parameters found in operation arguments, and added as + /// new region inputs. + parameters: IndexMap>, + /// A list of input variables added to the hugr. + /// + /// Ordered according to their order in the function input. + parameter_vars: IndexSet, +} + +impl WireTracker { + /// Returns a new WireTracker with the given capacity. + pub fn with_capacity(qubit_count: usize, bit_count: usize) -> Self { + WireTracker { + wires: IndexMap::new(), + qubits: Vec::with_capacity(qubit_count), + bits: Vec::with_capacity(bit_count), + latest_qubit_tracker: IndexMap::with_capacity(qubit_count), + latest_bit_tracker: IndexMap::with_capacity(bit_count), + qubit_wires: IndexMap::with_capacity(qubit_count), + bit_wires: IndexMap::with_capacity(bit_count), + parameters: IndexMap::new(), + parameter_vars: IndexSet::new(), + } + } + + /// Closes the WireTracker. + /// + /// Returns: + /// - A list of qubit and bit elements, in the order they were added. + /// - A list of input parameter added to the hugr, in the order they were added. + pub(super) fn finish(self) -> IndexSet { + self.parameter_vars + } + + /// Returns a reference to the tracked qubit at the given index. + fn get_qubit(&self, id: TrackedQubitId) -> &TrackedQubit { + &self.qubits[id.0] + } + + /// Returns a reference to the tracked bit at the given index. + fn get_bit(&self, id: TrackedBitId) -> &TrackedBit { + &self.bits[id.0] + } + + /// Returns `true` if the given register is a known bit register. + pub(super) fn is_known_bit(&self, register: &PytketRegister) -> bool { + self.latest_bit_tracker + .contains_key(&RegisterHash::from(register)) + } + + /// Returns the list of known pytket registers, in the order they were registered. + pub(super) fn known_pytket_qubits(&self) -> impl Iterator { + self.latest_qubit_tracker + .iter() + .map(|(_, &elem_id)| self.get_qubit(elem_id)) + } + + /// Returns the list of known pytket bit registers, in the order they were registered. + pub(super) fn known_pytket_bits(&self) -> impl Iterator { + self.latest_bit_tracker + .iter() + .map(|(_, &elem_id)| self.get_bit(elem_id)) + } + + /// Track a new pytket qubit register. + /// + /// If the pytket register was already in the tracker, + /// marks the previous element as outdated. + /// + /// Returns the hash of the register. + pub(super) fn track_qubit( + &mut self, + qubit_reg: Arc, + ) -> Result { + check_register(&qubit_reg)?; + + let id = TrackedQubitId(self.qubits.len()); + let hash = RegisterHash::from(qubit_reg.as_ref()); + self.qubits.push(TrackedQubit::new(qubit_reg)); + if let Some(previous_id) = self.latest_qubit_tracker.insert(hash, id) { + self.qubits[previous_id.0].mark_outdated(); + } + Ok(id) + } + + /// Track a new pytket bit register. + /// + /// If the pytket register was already in the tracker, + /// marks the previous element as outdated. + /// + /// Returns the hash of the register. + pub(super) fn track_bit( + &mut self, + bit_reg: Arc, + ) -> Result { + check_register(&bit_reg)?; + + let id = TrackedBitId(self.bits.len()); + let hash = RegisterHash::from(bit_reg.as_ref()); + self.bits.push(TrackedBit::new(bit_reg)); + if let Some(previous_id) = self.latest_bit_tracker.insert(hash, id) { + self.bits[previous_id.0].mark_outdated(); + } + + Ok(id) + } + + /// Returns the latest tracked qubit for a pytket register. + /// + /// Returns an error if the register is not known. + /// + /// The returned element is guaranteed to be up to date (See [`TrackedQubit::is_outdated`]). + pub fn tracked_qubit_for_register( + &self, + register: impl AsRef, + ) -> Result<&TrackedQubit, Tk1DecodeError> { + let hash = RegisterHash::from(register.as_ref()); + let Some(id) = self.latest_qubit_tracker.get(&hash) else { + return Err(Tk1DecodeError::unknown_qubit_reg(register.as_ref())); + }; + Ok(self.get_qubit(*id)) + } + + /// Returns the latest tracked bit for a pytket register. + /// + /// Returns an error if the register is not known. + /// + /// The returned element is guaranteed to be up to date (See [`TrackedBit::is_outdated`]). + pub fn tracked_bit_for_register( + &self, + register: impl AsRef, + ) -> Result<&TrackedBit, Tk1DecodeError> { + let hash = RegisterHash::from(register.as_ref()); + let Some(id) = self.latest_bit_tracker.get(&hash) else { + return Err(Tk1DecodeError::unknown_bit_reg(register.as_ref())); + }; + Ok(self.get_bit(*id)) + } + + /// Returns a new set of [TrackedWires] for a list of + /// [`circuit_json::Command`][tket_json_rs::circuit_json::Command] inputs. + /// + /// Returns an error if a valid set cannot be found. + /// + /// # Arguments + /// + /// * `hugr` - The hugr to add the wires to. + /// * `args` - The list of pytket element ids to map to wires. + /// * `operation` - The name of the operation being decoded, used for error reporting. + /// * `params` - The list of parameters to load to wires. See [`WireTracker::load_parameter`] for more details. + /// + // TODO: We'll need to be able to decompose types when we need only _some_ + // of the elements they contain (E.g., extract a value from an array), + // and do it automatically here. + pub(super) fn wire_inputs_for_command<'r>( + &mut self, + hugr: &mut FunctionBuilder<&mut Hugr>, + qubit_args: impl IntoIterator, + bit_args: impl IntoIterator, + params: impl IntoIterator, + operation: &str, + ) -> Result { + // We need to return a set of wires that contain all the arguments. + // + // We collect this by checking the wires where each element is present, + // and collecting them in order. + let mut qubit_args: VecDeque<(TrackedQubitId, &PytketRegister)> = qubit_args + .into_iter() + .map( + |register| match self.latest_qubit_tracker.get(&RegisterHash::from(register)) { + Some(id) => Ok((*id, register)), + None => Err(Tk1DecodeError::unknown_qubit_reg(register)), + }, + ) + .collect::>()?; + let mut bit_args: VecDeque<(TrackedBitId, &PytketRegister)> = bit_args + .into_iter() + .map( + |register| match self.latest_bit_tracker.get(&RegisterHash::from(register)) { + Some(id) => Ok((*id, register)), + None => Err(Tk1DecodeError::unknown_bit_reg(register)), + }, + ) + .collect::>()?; + + let mut value_wires = Vec::new(); + while !qubit_args.is_empty() || !bit_args.is_empty() { + // Check candidate wires that only contain the elements we need, in the right order. + let filter_candidate_wire = |w: Wire| { + let mut wire_qubits = self.wires[&w].qubits.iter().peekable(); + let mut wire_bits = self.wires[&w].bits.iter().peekable(); + let mut q_args_iter = qubit_args.iter().map(|(id, _)| id); + let mut b_args_iter = bit_args.iter().map(|(id, _)| id); + + // Check that each argument appears as either a qubit or a bit + // in the wire, in the right order. + // + // We may have leftover arguments at the end, which we'll try to + // get from another wire. + while wire_qubits.peek().is_some() && wire_bits.peek().is_some() { + if let Some(qb) = wire_qubits.next() { + match q_args_iter.next() { + Some(arg) if qb == arg => continue, + _ => return false, + }; + } + if let Some(bit) = wire_bits.next() { + match b_args_iter.next() { + Some(arg) if bit == arg => continue, + _ => return false, + }; + } + return false; + } + true + }; + + let qubit_candidates = qubit_args + .front() + .into_iter() + .flat_map(|(id, _)| self.qubit_wires[id].iter()); + let bit_candidates = bit_args + .front() + .into_iter() + .flat_map(|(id, _)| self.bit_wires[id].iter()); + let candidate = qubit_candidates + .chain(bit_candidates) + .find(|&&w| filter_candidate_wire(w)); + + // If we found a candidate, add it to the list of wires. + match candidate { + Some(w) => { + // Consume the extracted args, and add the wire to the list. + let wire_data: WireData = self.wires[w].clone(); + qubit_args.drain(..wire_data.num_qubits()); + bit_args.drain(..wire_data.num_bits()); + value_wires.push(wire_data); + } + None => { + // In the future we may be able to decompose some wire containing `arg_ids[0]` internally. + // For now, we just report the error. + return Err(Tk1DecodeError::ArgumentCouldNotBeMapped { + operation: operation.to_string(), + qubit_args: qubit_args + .iter() + .map(|(_, elem)| elem.to_string()) + .collect(), + bit_args: bit_args.iter().map(|(_, elem)| elem.to_string()).collect(), + }); + } + } + } + + // Load the parameters. + let parameter_wires = params + .into_iter() + .map(|param| self.load_parameter(hugr, param)) + .collect_vec(); + + Ok(TrackedWires { + value_wires, + parameter_wires, + }) + } + + /// Returns the wire carrying a parameter. + /// + /// - If the parameter is a known algebraic operation, adds the required op and recurses on its inputs. + /// - If the parameter is a constant, a constant definition is added to the Hugr. + /// - If the parameter is a variable, adds a new `rotation` input to the region. + /// - If the parameter is a sympy expressions, adds it as a [`SympyOpDef`][crate::extension::sympy::SympyOpDef] black box. + pub fn load_parameter( + &mut self, + hugr: &mut FunctionBuilder<&mut Hugr>, + param: &str, + ) -> Arc { + fn process( + hugr: &mut FunctionBuilder<&mut Hugr>, + input_params: &mut IndexMap>, + param_vars: &mut IndexSet, + parsed: PytketParam, + param: &str, + ) -> Arc { + match parsed { + PytketParam::Constant(half_turns) => { + let value: Value = ConstF64::new(half_turns).into(); + let wire = hugr.add_load_const(value); + Arc::new(LoadedParameter::float(wire)) + } + PytketParam::Sympy(expr) => { + // store string in custom op. + let symb_op = symbolic_constant_op(expr.to_string()); + let wire = hugr.add_dataflow_op(symb_op, []).unwrap().out_wire(0); + Arc::new(LoadedParameter::rotation(wire)) + } + PytketParam::InputVariable { name } => { + // Special case for the name "pi", inserts a `ConstRotation::PI` instead. + if name == "pi" { + let value: Value = ConstF64::new(std::f64::consts::PI).into(); + let wire = hugr.add_load_const(value); + return Arc::new(LoadedParameter::float(wire)); + } + // Look it up in the input parameters to the circuit, and add a new wire if needed. + input_params + .entry(name.to_string()) + .or_insert_with(|| { + param_vars.insert(name.to_string()); + let wire = hugr.add_input(rotation_type()); + Arc::new(LoadedParameter::rotation(wire)) + }) + .clone() + } + PytketParam::Operation { op, args } => { + // We assume all operations take float inputs. + let input_wires = args + .into_iter() + .map(|arg| { + process(hugr, input_params, param_vars, arg, param) + .as_float(hugr) + .wire + }) + .collect_vec(); + // If any of the following asserts panics, it means we added invalid ops to the sympy parser. + let res = hugr.add_dataflow_op(op, input_wires).unwrap_or_else(|e| { + panic!("Error while decoding pytket operation parameter \"{param}\". {e}",) + }); + assert_eq!(res.num_value_outputs(), 1, "An operation decoded from the pytket op parameter \"{param}\" had {} outputs", res.num_value_outputs()); + Arc::new(LoadedParameter::float(res.out_wire(0))) + } + } + } + + let parsed = parse_pytket_param(param); + process( + hugr, + &mut self.parameters, + &mut self.parameter_vars, + parsed, + param, + ) + } + + /// Track a new wire, updating any tracked elements that are present in it. + pub fn track_wire( + &mut self, + wire: Wire, + ty: Arc, + qubits: impl IntoIterator, + bits: impl IntoIterator, + ) -> Result<(), Tk1DecodeError> { + let qubits = qubits + .into_iter() + .map(|q| self.track_qubit(q.pytket_register_arc())) + .collect::>()?; + let bits = bits + .into_iter() + .map(|b| self.track_bit(b.pytket_register_arc())) + .collect::>()?; + + for &q in &qubits { + self.qubit_wires.entry(q).or_default().push(wire); + } + for &b in &bits { + self.bit_wires.entry(b).or_default().push(wire); + } + + let wire_data = WireData { + wire, + ty, + qubits, + bits, + }; + self.wires.insert(wire, wire_data); + + Ok(()) + } + + pub(crate) fn register_input_parameter( + &mut self, + wire: Wire, + param: String, + ) -> Result<(), Tk1DecodeError> { + let entry = self.parameters.entry(param); + if let indexmap::map::Entry::Occupied(_) = &entry { + return Err(Tk1DecodeError::DuplicatedParameter { + param: entry.key().clone(), + }); + } + entry.insert_entry(Arc::new(LoadedParameter::rotation(wire))); + Ok(()) + } +} + +/// Only single-indexed registers are supported. +fn check_register(register: &PytketRegister) -> Result<(), Tk1DecodeError> { + if register.1.len() != 1 { + Err(Tk1DecodeError::MultiIndexedRegister { + register: register.to_string(), + }) + } else { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::serialize::pytket::decoder::Tk1DecodeError; + use hugr::extension::prelude::{bool_t, qb_t}; + use hugr::types::SumType; + use hugr::Node; + use rstest::{fixture, rstest}; + use std::sync::Arc; + use tket_json_rs::register::ElementId; + + #[fixture] + fn sample_wire(#[default(0)] wire_idx: usize) -> Wire { + Wire::new(Node::from(portgraph::NodeIndex::new(wire_idx)), 0) + } + + // Test basic WireTracker creation + #[rstest] + fn tracker_properties() { + let mut tracker = WireTracker::with_capacity(5, 3); + let qubit_reg = Arc::new(ElementId("q".to_string(), vec![0])); + let bit_reg = Arc::new(ElementId("c".to_string(), vec![0])); + let multi_indexed_reg = Arc::new(ElementId("q".to_string(), vec![0, 1])); + let wire1 = sample_wire(1); + + // Initially, everything is empty - test through public methods + assert_eq!(tracker.known_pytket_qubits().count(), 0); + assert_eq!(tracker.known_pytket_bits().count(), 0); + + // Track an invalid register name. + match tracker.track_qubit(multi_indexed_reg.clone()) { + Err(Tk1DecodeError::MultiIndexedRegister { register }) => { + assert_eq!(register, multi_indexed_reg.to_string()); + } + e => panic!("Expected MultiIndexedRegister error, got {e:?}"), + } + + // Getting the tracked qubits or bits for an unknown register should fail. + match tracker.tracked_qubit_for_register(&qubit_reg) { + Err(Tk1DecodeError::UnknownQubitRegister { register }) => { + assert_eq!(register, qubit_reg.to_string()); + } + e => panic!("Expected UnknownQubitRegister error, got {e:?}"), + } + match tracker.tracked_bit_for_register(&bit_reg) { + Err(Tk1DecodeError::UnknownBitRegister { register }) => { + assert_eq!(register, bit_reg.to_string()); + } + e => panic!("Expected UnknownBitRegister error, got {e:?}"), + } + + // Track a new qubit + let tracked_q_0 = tracker + .track_qubit(qubit_reg.clone()) + .expect("Should track qubit"); + assert_eq!(tracker.known_pytket_qubits().count(), 1); + assert_eq!(tracker.known_pytket_bits().count(), 0); + let tracked_qubit = tracker + .tracked_qubit_for_register(&qubit_reg) + .expect("Should find tracked qubit"); + assert!(!tracked_qubit.is_outdated()); + assert_eq!(tracked_qubit, tracker.get_qubit(tracked_q_0)); + + // Track the same qubit again, it should add a new TrackedQubit and mark the previous one as outdated + let tracked_q_1 = tracker + .track_qubit(qubit_reg.clone()) + .expect("Should track qubit again"); + assert_eq!(tracker.known_pytket_qubits().count(), 1); // still only one unique register + assert!(tracker.get_qubit(tracked_q_0).is_outdated()); + assert!(!tracker.get_qubit(tracked_q_1).is_outdated()); + let tracked_qubit = tracker + .tracked_qubit_for_register(&qubit_reg) + .expect("Should find latest tracked qubit") + .clone(); + assert_eq!(&tracked_qubit, tracker.get_qubit(tracked_q_1)); + + // Track a bit + let bit_id = tracker + .track_bit(bit_reg.clone()) + .expect("Should track bit"); + assert_eq!(tracker.known_pytket_bits().count(), 1); + assert!(!tracker.get_bit(bit_id).is_outdated()); + let tracked_bit = tracker + .tracked_bit_for_register(&bit_reg) + .expect("Should find tracked bit") + .clone(); + assert_eq!(&tracked_bit, tracker.get_bit(bit_id)); + + // Associate the bit and qubit with a wire. + tracker + .track_wire( + wire1, + Arc::new(SumType::new_tuple(vec![qb_t(), bool_t()]).into()), + vec![tracked_qubit.clone()], + vec![tracked_bit.clone()], + ) + .expect("Should track wire"); + } +} From 6e426757125c70b6fefc59d8fc20ce599282eacc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Wed, 13 Aug 2025 18:10:02 +0100 Subject: [PATCH 04/11] Actually compile the wire tracker module --- tket/src/serialize/pytket/decoder.rs | 9 +++++++++ tket/src/serialize/pytket/decoder/param.rs | 4 ---- tket/src/serialize/pytket/decoder/tracked_elem.rs | 4 ++++ tket/src/serialize/pytket/decoder/wires.rs | 5 ++++- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/tket/src/serialize/pytket/decoder.rs b/tket/src/serialize/pytket/decoder.rs index 50c353e4a..af5480c79 100644 --- a/tket/src/serialize/pytket/decoder.rs +++ b/tket/src/serialize/pytket/decoder.rs @@ -2,12 +2,20 @@ mod op; mod param; +mod tracked_elem; +mod wires; #[expect( unused_imports, reason = "Temporarily unused while we refactor the pytket decoder" )] pub use param::{LoadedParameter, LoadedParameterType}; +pub use tracked_elem::{TrackedBit, TrackedQubit}; +#[expect( + unused_imports, + reason = "Temporarily unused while we refactor the pytket decoder" +)] +pub use wires::{TrackedWires, WireData, WireTracker}; use std::collections::{HashMap, HashSet}; @@ -19,6 +27,7 @@ use hugr::ops::{OpType, Value}; use hugr::std_extensions::arithmetic::float_types::ConstF64; use hugr::types::Signature; use hugr::{Hugr, Wire}; +use tracked_elem::{TrackedBitId, TrackedQubitId}; use indexmap::IndexMap; use itertools::{EitherOrBoth, Itertools}; diff --git a/tket/src/serialize/pytket/decoder/param.rs b/tket/src/serialize/pytket/decoder/param.rs index cae1eed45..b9d5cda5a 100644 --- a/tket/src/serialize/pytket/decoder/param.rs +++ b/tket/src/serialize/pytket/decoder/param.rs @@ -48,10 +48,6 @@ impl LoadedParameter { } /// Returns the hugr type for the parameter. - #[expect( - dead_code, - reason = "Temporarily unused while we refactor the pytket decoder" - )] pub fn wire_type(&self) -> &Type { static FLOAT_TYPE: LazyLock = LazyLock::new(float64_type); static ROTATION_TYPE: LazyLock = LazyLock::new(rotation_type); diff --git a/tket/src/serialize/pytket/decoder/tracked_elem.rs b/tket/src/serialize/pytket/decoder/tracked_elem.rs index 4b914da1b..af0a9cb8a 100644 --- a/tket/src/serialize/pytket/decoder/tracked_elem.rs +++ b/tket/src/serialize/pytket/decoder/tracked_elem.rs @@ -1,4 +1,8 @@ //! Pytket qubit and bit elements that we track during decoding. +#![allow( + dead_code, + reason = "Temporarily unused while we refactor the pytket decoder" +)] use std::sync::{Arc, LazyLock}; diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index ecebcb51d..c987feaa8 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -1,5 +1,9 @@ //! Structures to keep track of pytket [`ElementId`][tket_json_rs::register::ElementId]s and //! their correspondence to wires in the hugr being defined. +#![expect( + dead_code, + reason = "Temporarily unused while we refactor the pytket decoder" +)] use std::collections::VecDeque; use std::sync::Arc; @@ -763,7 +767,6 @@ fn check_register(register: &PytketRegister) -> Result<(), Tk1DecodeError> { #[cfg(test)] mod tests { use super::*; - use crate::serialize::pytket::decoder::Tk1DecodeError; use hugr::extension::prelude::{bool_t, qb_t}; use hugr::types::SumType; use hugr::Node; From 0bb0c437d893088d0d8c0931ac0ce87cb8b7f0e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 14 Aug 2025 12:11:47 +0100 Subject: [PATCH 05/11] Simplify `wire_inputs_for_command` logic --- tket/src/serialize/pytket/decoder/wires.rs | 42 +++++++++------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index c987feaa8..eb3d5528e 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -548,33 +548,23 @@ impl WireTracker { let mut value_wires = Vec::new(); while !qubit_args.is_empty() || !bit_args.is_empty() { // Check candidate wires that only contain the elements we need, in the right order. + // + // We may have leftover arguments at the end, which we'll try to + // get from another wire. let filter_candidate_wire = |w: Wire| { - let mut wire_qubits = self.wires[&w].qubits.iter().peekable(); - let mut wire_bits = self.wires[&w].bits.iter().peekable(); - let mut q_args_iter = qubit_args.iter().map(|(id, _)| id); - let mut b_args_iter = bit_args.iter().map(|(id, _)| id); - - // Check that each argument appears as either a qubit or a bit - // in the wire, in the right order. - // - // We may have leftover arguments at the end, which we'll try to - // get from another wire. - while wire_qubits.peek().is_some() && wire_bits.peek().is_some() { - if let Some(qb) = wire_qubits.next() { - match q_args_iter.next() { - Some(arg) if qb == arg => continue, - _ => return false, - }; - } - if let Some(bit) = wire_bits.next() { - match b_args_iter.next() { - Some(arg) if bit == arg => continue, - _ => return false, - }; - } - return false; - } - true + let wire_data = &self.wires[&w]; + let same_qubits = itertools::equal( + wire_data.qubits.iter(), + qubit_args + .iter() + .map(|(id, _)| id) + .take(wire_data.num_qubits()), + ); + let same_bits = itertools::equal( + wire_data.bits.iter(), + bit_args.iter().map(|(id, _)| id).take(wire_data.num_bits()), + ); + same_qubits && same_bits }; let qubit_candidates = qubit_args From 3739373b5f7817262118a5c20dec7ad2169e7999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 14 Aug 2025 15:38:34 +0100 Subject: [PATCH 06/11] rename `Tk1DecodeError` to `PytketDecodeError` --- tket-py/src/circuit.rs | 2 +- tket/src/serialize/pytket.rs | 4 +- tket/src/serialize/pytket/decoder/wires.rs | 51 ++++++++++++---------- 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/tket-py/src/circuit.rs b/tket-py/src/circuit.rs index 9e2aa73d6..6ac06ef53 100644 --- a/tket-py/src/circuit.rs +++ b/tket-py/src/circuit.rs @@ -82,7 +82,7 @@ create_py_exception!( ); create_py_exception!( - tket::serialize::pytket::Tk1DecodeError, + tket::serialize::pytket::PytketDecodeError, PyTK1DecodeError, "Error type for the conversion between tket1 and tket operations." ); diff --git a/tket/src/serialize/pytket.rs b/tket/src/serialize/pytket.rs index 8dbbb347b..92c069c53 100644 --- a/tket/src/serialize/pytket.rs +++ b/tket/src/serialize/pytket.rs @@ -282,7 +282,7 @@ impl Tk1ConvertError { /// Error type for conversion between tket2 ops and pytket operations. #[derive(derive_more::Debug, Display, Error)] #[non_exhaustive] -pub enum Tk1DecodeError { +pub enum PytketDecodeError { /// The pytket circuit uses multi-indexed registers. // // This could be supported in the future, if there is a need for it. @@ -424,7 +424,7 @@ pub enum Tk1DecodeError { }, } -impl Tk1DecodeError { +impl PytketDecodeError { /// Create a new error with a custom message. pub fn custom(msg: impl ToString) -> Self { Self::CustomError { diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index eb3d5528e..73c779a15 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -23,7 +23,7 @@ use crate::serialize::pytket::decoder::{ LoadedParameter, TrackedBit, TrackedBitId, TrackedQubit, TrackedQubitId, }; use crate::serialize::pytket::extension::RegisterCount; -use crate::serialize::pytket::{RegisterHash, Tk1DecodeError}; +use crate::serialize::pytket::{PytketDecodeError, RegisterHash}; use crate::symbolic_constant_op; /// Tracked data for a wire in [`TrackedWires`]. @@ -236,7 +236,10 @@ impl TrackedWires { /// /// * `operation` - The name of the operation being decoded, used for error reporting. #[inline] - pub fn wires_arr(&self, operation: &str) -> Result<[Wire; N], Tk1DecodeError> { + pub fn wires_arr( + &self, + operation: &str, + ) -> Result<[Wire; N], PytketDecodeError> { let expected_values = N.saturating_sub(self.parameter_count()); let expected_params = N - expected_values; self.check_len(expected_values, expected_params, operation)?; @@ -270,10 +273,10 @@ impl TrackedWires { expected_values: usize, expected_params: usize, operation: &str, - ) -> Result<(), Tk1DecodeError> { + ) -> Result<(), PytketDecodeError> { if self.value_count() != expected_values || self.parameter_count() != expected_params { let types = self.wire_types().map(|ty| ty.to_string()).collect_vec(); - Err(Tk1DecodeError::UnexpectedInputWires { + Err(PytketDecodeError::UnexpectedInputWires { expected_values, expected_params, actual_values: self.value_count(), @@ -300,12 +303,12 @@ impl TrackedWires { expected_values: &[Type], expected_params: usize, operation: &str, - ) -> Result<(), Tk1DecodeError> { + ) -> Result<(), PytketDecodeError> { let vals = expected_values.iter(); if !itertools::equal(self.value_types(), vals) || self.parameter_count() != expected_params { let actual = self.value_types().collect_vec(); - Err(Tk1DecodeError::UnexpectedInputWires { + Err(PytketDecodeError::UnexpectedInputWires { expected_values: expected_values.len(), expected_params, actual_values: self.value_count(), @@ -433,7 +436,7 @@ impl WireTracker { pub(super) fn track_qubit( &mut self, qubit_reg: Arc, - ) -> Result { + ) -> Result { check_register(&qubit_reg)?; let id = TrackedQubitId(self.qubits.len()); @@ -454,7 +457,7 @@ impl WireTracker { pub(super) fn track_bit( &mut self, bit_reg: Arc, - ) -> Result { + ) -> Result { check_register(&bit_reg)?; let id = TrackedBitId(self.bits.len()); @@ -475,10 +478,10 @@ impl WireTracker { pub fn tracked_qubit_for_register( &self, register: impl AsRef, - ) -> Result<&TrackedQubit, Tk1DecodeError> { + ) -> Result<&TrackedQubit, PytketDecodeError> { let hash = RegisterHash::from(register.as_ref()); let Some(id) = self.latest_qubit_tracker.get(&hash) else { - return Err(Tk1DecodeError::unknown_qubit_reg(register.as_ref())); + return Err(PytketDecodeError::unknown_qubit_reg(register.as_ref())); }; Ok(self.get_qubit(*id)) } @@ -491,10 +494,10 @@ impl WireTracker { pub fn tracked_bit_for_register( &self, register: impl AsRef, - ) -> Result<&TrackedBit, Tk1DecodeError> { + ) -> Result<&TrackedBit, PytketDecodeError> { let hash = RegisterHash::from(register.as_ref()); let Some(id) = self.latest_bit_tracker.get(&hash) else { - return Err(Tk1DecodeError::unknown_bit_reg(register.as_ref())); + return Err(PytketDecodeError::unknown_bit_reg(register.as_ref())); }; Ok(self.get_bit(*id)) } @@ -521,7 +524,7 @@ impl WireTracker { bit_args: impl IntoIterator, params: impl IntoIterator, operation: &str, - ) -> Result { + ) -> Result { // We need to return a set of wires that contain all the arguments. // // We collect this by checking the wires where each element is present, @@ -531,7 +534,7 @@ impl WireTracker { .map( |register| match self.latest_qubit_tracker.get(&RegisterHash::from(register)) { Some(id) => Ok((*id, register)), - None => Err(Tk1DecodeError::unknown_qubit_reg(register)), + None => Err(PytketDecodeError::unknown_qubit_reg(register)), }, ) .collect::>()?; @@ -540,7 +543,7 @@ impl WireTracker { .map( |register| match self.latest_bit_tracker.get(&RegisterHash::from(register)) { Some(id) => Ok((*id, register)), - None => Err(Tk1DecodeError::unknown_bit_reg(register)), + None => Err(PytketDecodeError::unknown_bit_reg(register)), }, ) .collect::>()?; @@ -591,7 +594,7 @@ impl WireTracker { None => { // In the future we may be able to decompose some wire containing `arg_ids[0]` internally. // For now, we just report the error. - return Err(Tk1DecodeError::ArgumentCouldNotBeMapped { + return Err(PytketDecodeError::ArgumentCouldNotBeMapped { operation: operation.to_string(), qubit_args: qubit_args .iter() @@ -699,7 +702,7 @@ impl WireTracker { ty: Arc, qubits: impl IntoIterator, bits: impl IntoIterator, - ) -> Result<(), Tk1DecodeError> { + ) -> Result<(), PytketDecodeError> { let qubits = qubits .into_iter() .map(|q| self.track_qubit(q.pytket_register_arc())) @@ -731,10 +734,10 @@ impl WireTracker { &mut self, wire: Wire, param: String, - ) -> Result<(), Tk1DecodeError> { + ) -> Result<(), PytketDecodeError> { let entry = self.parameters.entry(param); if let indexmap::map::Entry::Occupied(_) = &entry { - return Err(Tk1DecodeError::DuplicatedParameter { + return Err(PytketDecodeError::DuplicatedParameter { param: entry.key().clone(), }); } @@ -744,9 +747,9 @@ impl WireTracker { } /// Only single-indexed registers are supported. -fn check_register(register: &PytketRegister) -> Result<(), Tk1DecodeError> { +fn check_register(register: &PytketRegister) -> Result<(), PytketDecodeError> { if register.1.len() != 1 { - Err(Tk1DecodeError::MultiIndexedRegister { + Err(PytketDecodeError::MultiIndexedRegister { register: register.to_string(), }) } else { @@ -784,7 +787,7 @@ mod tests { // Track an invalid register name. match tracker.track_qubit(multi_indexed_reg.clone()) { - Err(Tk1DecodeError::MultiIndexedRegister { register }) => { + Err(PytketDecodeError::MultiIndexedRegister { register }) => { assert_eq!(register, multi_indexed_reg.to_string()); } e => panic!("Expected MultiIndexedRegister error, got {e:?}"), @@ -792,13 +795,13 @@ mod tests { // Getting the tracked qubits or bits for an unknown register should fail. match tracker.tracked_qubit_for_register(&qubit_reg) { - Err(Tk1DecodeError::UnknownQubitRegister { register }) => { + Err(PytketDecodeError::UnknownQubitRegister { register }) => { assert_eq!(register, qubit_reg.to_string()); } e => panic!("Expected UnknownQubitRegister error, got {e:?}"), } match tracker.tracked_bit_for_register(&bit_reg) { - Err(Tk1DecodeError::UnknownBitRegister { register }) => { + Err(PytketDecodeError::UnknownBitRegister { register }) => { assert_eq!(register, bit_reg.to_string()); } e => panic!("Expected UnknownBitRegister error, got {e:?}"), From b2606b9d970aacc93a8b9dd8c8a2cbd1c7d81cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 14 Aug 2025 18:24:24 +0100 Subject: [PATCH 07/11] review comments --- tket/src/serialize/pytket/decoder/wires.rs | 33 +++++++++++----------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index 73c779a15..e8e2116a7 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -93,7 +93,19 @@ impl WireData { } } -/// Tracked wires to a pytket operation. +/// Set of wires related to a pytket operation being decoded. +/// +/// Contains both _parameter_ and _value_ wires. +/// +/// The _parameter_ wires are wires that contain a single [`LoadedParameter`] +/// (either a float or a rotation) corresponding to the sympy expressions in the +/// operation arguments. +/// +/// The _value_ wires are wires that contain a collection of [`TrackedQubit`]s +/// and [`TrackedBit`]s. +/// +/// This set is passed to the implementer of `PytketDecoder` with the wires that +/// were found to contain the pytket registers used by the operation. #[derive(Debug, Clone)] pub struct TrackedWires { /// Computed list of wires corresponding to the arguments, @@ -147,9 +159,6 @@ impl TrackedWires { } /// Return an iterator over the wires and their types. - /// - /// This returns the wires as-is, without any additional conversions. - /// If you need to retrieve a specific wire type, use TODO #[inline] pub fn iter_values(&self) -> impl Iterator + Clone + '_ { self.value_wires.iter() @@ -190,15 +199,6 @@ impl TrackedWires { .flat_map(move |wd| wd.qubits(wire_tracker)) } - /// Returns the tracked qubit elements in the set of wires as an array. - #[inline] - pub fn qubits_arr( - &self, - wire_tracker: &WireTracker, - ) -> Option<[TrackedQubit; N]> { - self.qubits(wire_tracker).collect_array() - } - /// Returns the tracked bit elements in the set of wires. #[inline] pub fn bits<'d>( @@ -340,8 +340,8 @@ impl TrackedWires { /// operation, all previous references to it become "outdated". #[derive(Debug, Clone, Default)] pub struct WireTracker { - /// A list of tracked wires, with their type and list of - /// tracked pytket elements and arguments. + /// A map of wires being tracked, with their type and list of + /// tracked pytket registers and parameters. wires: IndexMap, /// The list of tracked qubit elements. /// @@ -685,12 +685,11 @@ impl WireTracker { } } - let parsed = parse_pytket_param(param); process( hugr, &mut self.parameters, &mut self.parameter_vars, - parsed, + parse_pytket_param(param), param, ) } From e998a125daa326decb46a5b64ca03f6559ec5da2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Thu, 14 Aug 2025 18:41:52 +0100 Subject: [PATCH 08/11] Add a context wrapper to the DecoderError, drop `operation` string arguments --- tket/src/serialize/pytket.rs | 107 +++++++++++++++------ tket/src/serialize/pytket/decoder/wires.rs | 62 ++++++------ 2 files changed, 106 insertions(+), 63 deletions(-) diff --git a/tket/src/serialize/pytket.rs b/tket/src/serialize/pytket.rs index 23050cea5..b8ae8c6b9 100644 --- a/tket/src/serialize/pytket.rs +++ b/tket/src/serialize/pytket.rs @@ -280,9 +280,81 @@ impl PytketEncodeError { } /// Error type for conversion between tket2 ops and pytket operations. -#[derive(derive_more::Debug, Display, Error)] +#[derive(derive_more::Debug, Display, Error, Clone)] #[non_exhaustive] -pub enum PytketDecodeError { +#[display( + "{inner}{context}", + context = { + match (pytket_op, hugr_op) { + (Some(pytket_op), Some(hugr_op)) => format!(". While decoding a pytket {pytket_op} as a hugr {hugr_op}"), + (Some(pytket_op), None) => format!(". While decoding a pytket {pytket_op}"), + (None, Some(hugr_op)) => format!(". While decoding a hugr {hugr_op}"), + (None, None) => String::new(), + } + }, +)] +pub struct PytketDecodeError { + /// The kind of error. + pub inner: PytketDecodeErrorInner, + /// The pytket operation that caused the error, if applicable. + pub pytket_op: Option, + /// The hugr operation that caused the error, if applicable. + pub hugr_op: Option, +} + +impl PytketDecodeError { + /// Create a new error with a custom message. + pub fn custom(msg: impl ToString) -> Self { + PytketDecodeErrorInner::CustomError { + msg: msg.to_string(), + } + .into() + } + + /// Create an error for an unknown qubit register. + pub fn unknown_qubit_reg(register: &tket_json_rs::register::ElementId) -> Self { + PytketDecodeErrorInner::UnknownQubitRegister { + register: register.to_string(), + } + .into() + } + + /// Create an error for an unknown bit register. + pub fn unknown_bit_reg(register: &tket_json_rs::register::ElementId) -> Self { + PytketDecodeErrorInner::UnknownBitRegister { + register: register.to_string(), + } + .into() + } + + /// Add the pytket operation name to the error. + pub fn pytket_op(mut self, op: &tket_json_rs::OpType) -> Self { + self.pytket_op = Some(format!("{op:?}")); + self + } + + /// Add the hugr operation name to the error. + pub fn hugr_op(mut self, op: impl ToString) -> Self { + self.hugr_op = Some(op.to_string()); + self + } +} + +impl From for PytketDecodeError { + fn from(inner: PytketDecodeErrorInner) -> Self { + Self { + inner, + pytket_op: None, + hugr_op: None, + } + } +} + +/// Error variants of [`PytketDecodeError`], signalling errors during the +/// conversion between tket2 ops and pytket operations. +#[derive(derive_more::Debug, Display, Error, Clone)] +#[non_exhaustive] +pub enum PytketDecodeErrorInner { /// The pytket circuit uses multi-indexed registers. // // This could be supported in the future, if there is a need for it. @@ -340,13 +412,11 @@ pub enum PytketDecodeError { // // Some of this errors will be avoided in the future once we are able to decompose complex types automatically. #[display( - "Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}] for operation {operation}.", + "Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}].", qubit_args = qubit_args.iter().join(", "), bit_args = bit_args.iter().join(", "), )] ArgumentCouldNotBeMapped { - /// The operation type that was being decoded. - operation: String, /// The qubit arguments that couldn't be mapped. qubit_args: Vec, /// The bit arguments that couldn't be mapped. @@ -354,7 +424,7 @@ pub enum PytketDecodeError { }, /// Found an unexpected number of input wires when decoding an operation. #[display( - "Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters when decoding a {operation}, but found {actual_values} values{actual_types} and {actual_params} parameters.", + "Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters, but found {actual_values} values{actual_types} and {actual_params} parameters.", expected_types = match expected_types { None => "".to_string(), Some(tys) => format!(" with types [{}]", tys.iter().join(", ")), @@ -377,8 +447,6 @@ pub enum PytketDecodeError { expected_types: Option>, /// The actual types of the input wires. actual_types: Option>, - /// The operation type that was being decoded. - operation: String, }, /// Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. #[display( @@ -424,29 +492,6 @@ pub enum PytketDecodeError { }, } -impl PytketDecodeError { - /// Create a new error with a custom message. - pub fn custom(msg: impl ToString) -> Self { - Self::CustomError { - msg: msg.to_string(), - } - } - - /// Create an error for an unknown qubit register. - pub fn unknown_qubit_reg(register: &tket_json_rs::register::ElementId) -> Self { - Self::UnknownQubitRegister { - register: register.to_string(), - } - } - - /// Create an error for an unknown bit register. - pub fn unknown_bit_reg(register: &tket_json_rs::register::ElementId) -> Self { - Self::UnknownBitRegister { - register: register.to_string(), - } - } -} - /// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map, /// avoiding string and vector clones on lookup. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index e8e2116a7..971a2b157 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -23,7 +23,7 @@ use crate::serialize::pytket::decoder::{ LoadedParameter, TrackedBit, TrackedBitId, TrackedQubit, TrackedQubitId, }; use crate::serialize::pytket::extension::RegisterCount; -use crate::serialize::pytket::{PytketDecodeError, RegisterHash}; +use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner, RegisterHash}; use crate::symbolic_constant_op; /// Tracked data for a wire in [`TrackedWires`]. @@ -231,18 +231,11 @@ impl TrackedWires { /// Returns the wires in this tracked wires as an array of types. /// /// Returns an error if the number of wires is not equal to `N`. - /// - /// # Arguments - /// - /// * `operation` - The name of the operation being decoded, used for error reporting. #[inline] - pub fn wires_arr( - &self, - operation: &str, - ) -> Result<[Wire; N], PytketDecodeError> { + pub fn wires_arr(&self) -> Result<[Wire; N], PytketDecodeError> { let expected_values = N.saturating_sub(self.parameter_count()); let expected_params = N - expected_values; - self.check_len(expected_values, expected_params, operation)?; + self.check_len(expected_values, expected_params)?; Ok(self .wires() .collect_array() @@ -267,24 +260,22 @@ impl TrackedWires { /// /// * `expected_values` - The expected number of value wires. /// * `expected_params` - The expected number of parameter wires. - /// * `operation` - The name of the operation being decoded, used for error reporting. pub fn check_len( &self, expected_values: usize, expected_params: usize, - operation: &str, ) -> Result<(), PytketDecodeError> { if self.value_count() != expected_values || self.parameter_count() != expected_params { let types = self.wire_types().map(|ty| ty.to_string()).collect_vec(); - Err(PytketDecodeError::UnexpectedInputWires { + Err(PytketDecodeErrorInner::UnexpectedInputWires { expected_values, expected_params, actual_values: self.value_count(), actual_params: self.parameter_count(), expected_types: None, actual_types: Some(types), - operation: operation.to_string(), - }) + } + .into()) } else { Ok(()) } @@ -297,18 +288,16 @@ impl TrackedWires { /// * `expected_values` - The expected types of the value wires. /// * `expected_params` - The expected number of parameters. Note that these may be either `float` or `rotation`-typed. /// Use [`LoadedParameter::with_type`] to cast them as needed. - /// * `operation` - The name of the operation being decoded, used for error reporting. pub fn check_types( &self, expected_values: &[Type], expected_params: usize, - operation: &str, ) -> Result<(), PytketDecodeError> { let vals = expected_values.iter(); if !itertools::equal(self.value_types(), vals) || self.parameter_count() != expected_params { let actual = self.value_types().collect_vec(); - Err(PytketDecodeError::UnexpectedInputWires { + Err(PytketDecodeErrorInner::UnexpectedInputWires { expected_values: expected_values.len(), expected_params, actual_values: self.value_count(), @@ -320,8 +309,8 @@ impl TrackedWires { .collect_vec(), ), actual_types: Some(actual.iter().map(|ty| ty.to_string()).collect_vec()), - operation: operation.to_string(), - }) + } + .into()) } else { Ok(()) } @@ -511,7 +500,6 @@ impl WireTracker { /// /// * `hugr` - The hugr to add the wires to. /// * `args` - The list of pytket element ids to map to wires. - /// * `operation` - The name of the operation being decoded, used for error reporting. /// * `params` - The list of parameters to load to wires. See [`WireTracker::load_parameter`] for more details. /// // TODO: We'll need to be able to decompose types when we need only _some_ @@ -523,7 +511,6 @@ impl WireTracker { qubit_args: impl IntoIterator, bit_args: impl IntoIterator, params: impl IntoIterator, - operation: &str, ) -> Result { // We need to return a set of wires that contain all the arguments. // @@ -594,14 +581,14 @@ impl WireTracker { None => { // In the future we may be able to decompose some wire containing `arg_ids[0]` internally. // For now, we just report the error. - return Err(PytketDecodeError::ArgumentCouldNotBeMapped { - operation: operation.to_string(), + return Err(PytketDecodeErrorInner::ArgumentCouldNotBeMapped { qubit_args: qubit_args .iter() .map(|(_, elem)| elem.to_string()) .collect(), bit_args: bit_args.iter().map(|(_, elem)| elem.to_string()).collect(), - }); + } + .into()); } } } @@ -736,9 +723,10 @@ impl WireTracker { ) -> Result<(), PytketDecodeError> { let entry = self.parameters.entry(param); if let indexmap::map::Entry::Occupied(_) = &entry { - return Err(PytketDecodeError::DuplicatedParameter { + return Err(PytketDecodeErrorInner::DuplicatedParameter { param: entry.key().clone(), - }); + } + .into()); } entry.insert_entry(Arc::new(LoadedParameter::rotation(wire))); Ok(()) @@ -748,9 +736,10 @@ impl WireTracker { /// Only single-indexed registers are supported. fn check_register(register: &PytketRegister) -> Result<(), PytketDecodeError> { if register.1.len() != 1 { - Err(PytketDecodeError::MultiIndexedRegister { + Err(PytketDecodeErrorInner::MultiIndexedRegister { register: register.to_string(), - }) + } + .into()) } else { Ok(()) } @@ -786,7 +775,10 @@ mod tests { // Track an invalid register name. match tracker.track_qubit(multi_indexed_reg.clone()) { - Err(PytketDecodeError::MultiIndexedRegister { register }) => { + Err(PytketDecodeError { + inner: PytketDecodeErrorInner::MultiIndexedRegister { register }, + .. + }) => { assert_eq!(register, multi_indexed_reg.to_string()); } e => panic!("Expected MultiIndexedRegister error, got {e:?}"), @@ -794,13 +786,19 @@ mod tests { // Getting the tracked qubits or bits for an unknown register should fail. match tracker.tracked_qubit_for_register(&qubit_reg) { - Err(PytketDecodeError::UnknownQubitRegister { register }) => { + Err(PytketDecodeError { + inner: PytketDecodeErrorInner::UnknownQubitRegister { register }, + .. + }) => { assert_eq!(register, qubit_reg.to_string()); } e => panic!("Expected UnknownQubitRegister error, got {e:?}"), } match tracker.tracked_bit_for_register(&bit_reg) { - Err(PytketDecodeError::UnknownBitRegister { register }) => { + Err(PytketDecodeError { + inner: PytketDecodeErrorInner::UnknownBitRegister { register }, + .. + }) => { assert_eq!(register, bit_reg.to_string()); } e => panic!("Expected UnknownBitRegister error, got {e:?}"), From be078957a47170b05f563884ad62c89de2bf2575 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Fri, 15 Aug 2025 02:44:16 +0100 Subject: [PATCH 09/11] Replace `wire_inputs_for_command` with a way simpler `find_typed_wires` --- tket/src/serialize/pytket.rs | 92 ++++-- .../serialize/pytket/decoder/tracked_elem.rs | 96 ++++++- tket/src/serialize/pytket/decoder/wires.rs | 263 +++++++++--------- 3 files changed, 299 insertions(+), 152 deletions(-) diff --git a/tket/src/serialize/pytket.rs b/tket/src/serialize/pytket.rs index b8ae8c6b9..bcdac46c0 100644 --- a/tket/src/serialize/pytket.rs +++ b/tket/src/serialize/pytket.rs @@ -29,6 +29,7 @@ use tket_json_rs::circuit_json::SerialCircuit; use tket_json_rs::register::{Bit, ElementId, Qubit}; use crate::circuit::Circuit; +use crate::serialize::pytket::extension::RegisterCount; use self::decoder::Tk1DecoderContext; pub use crate::passes::pytket::lower_to_pytket; @@ -358,19 +359,19 @@ pub enum PytketDecodeErrorInner { /// The pytket circuit uses multi-indexed registers. // // This could be supported in the future, if there is a need for it. - #[display("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers.")] + #[display("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers")] MultiIndexedRegister { /// The register name. register: String, }, /// Found an unexpected register name. - #[display("Found an unknown qubit register name: {register}.")] + #[display("Found an unknown qubit register name: {register}")] UnknownQubitRegister { /// The unknown register name. register: String, }, /// Found an unexpected bit register name. - #[display("Found an unknown bit register name: {register}.")] + #[display("Found an unknown bit register name: {register}")] UnknownBitRegister { /// The unknown register name. register: String, @@ -379,7 +380,7 @@ pub enum PytketDecodeErrorInner { /// /// The expected number of qubits and bits may be different depending on the [`PytketTypeTranslator`][extension::PytketTypeTranslator]s used in the decoder config. #[display( - "The given input types {input_types} to use for the HUGR's input wires are not compatible with the number of qubits and bits in the pytket circuit. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits.", + "The given input types {input_types} to use for the HUGR's input wires are not compatible with the number of qubits and bits in the pytket circuit. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits", input_types = input_types.iter().join(", "), )] InvalidInputSignature { @@ -398,7 +399,7 @@ pub enum PytketDecodeErrorInner { /// /// We don't do any kind of type conversion, so this depends solely on the last operation to update each register. #[display( - "The expected output types {expected_types} are not compatible with the actual output types {actual_types}, obtained from decoding the pytket circuit.", + "The expected output types {expected_types} are not compatible with the actual output types {actual_types}, obtained from decoding the pytket circuit", expected_types = expected_types.iter().join(", "), actual_types = actual_types.iter().join(", "), )] @@ -412,7 +413,7 @@ pub enum PytketDecodeErrorInner { // // Some of this errors will be avoided in the future once we are able to decompose complex types automatically. #[display( - "Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}].", + "Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}]", qubit_args = qubit_args.iter().join(", "), bit_args = bit_args.iter().join(", "), )] @@ -424,16 +425,16 @@ pub enum PytketDecodeErrorInner { }, /// Found an unexpected number of input wires when decoding an operation. #[display( - "Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters, but found {actual_values} values{actual_types} and {actual_params} parameters.", - expected_types = match expected_types { - None => "".to_string(), - Some(tys) => format!(" with types [{}]", tys.iter().join(", ")), - }, - actual_types = match actual_types { - None => "".to_string(), - Some(tys) => format!(" with types [{}]", tys.iter().join(", ")), - }, - )] + "Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters, but found {actual_values} values{actual_types} and {actual_params} parameters", + expected_types = match expected_types { + None => "".to_string(), + Some(tys) => format!(" with types [{}]", tys.iter().join(", ")), + }, + actual_types = match actual_types { + None => "".to_string(), + Some(tys) => format!(" with types [{}]", tys.iter().join(", ")), + }, + )] UnexpectedInputWires { /// The expected amount of input wires. expected_values: usize, @@ -448,9 +449,20 @@ pub enum PytketDecodeErrorInner { /// The actual types of the input wires. actual_types: Option>, }, + /// Found an unexpected input type when decoding an operation. + #[display( + "Found an unexpected type {unknown_type} in the input wires, in input signature ({all_types})", + all_types = all_types.iter().join(", "), + )] + UnexpectedInputType { + /// The unknown type. + unknown_type: String, + /// All the input types specified for the operation. + all_types: Vec, + }, /// Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. #[display( - "Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits in the node outputs." + "Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits in the node outputs" )] UnexpectedNodeOutput { /// The expected number of qubits. @@ -475,7 +487,7 @@ pub enum PytketDecodeErrorInner { param: String, }, /// Not enough parameter names given for the input signature. - #[display("Tried to initialize a pytket circuit decoder with {num_params_given} given parameter names, but more were required by the input signature.")] + #[display("Tried to initialize a pytket circuit decoder with {num_params_given} given parameter names, but more were required by the input signature")] MissingParametersInInput { /// The number of parameters given. num_params_given: usize, @@ -483,13 +495,55 @@ pub enum PytketDecodeErrorInner { /// We don't support complex types containing parameters in the input. // // This restriction may be relaxed in the future. - #[display("Complex type {ty} contains {num_params} inside it. We only support input parameters in standalone 'float' or 'rotation'-typed wires.")] + #[display("Complex type {ty} contains {num_params} inside it. We only support input parameters in standalone 'float' or 'rotation'-typed wires")] UnsupportedParametersInInput { /// The type that contains the parameters. ty: String, /// The number of parameters in the type. num_params: usize, }, + /// We couldn't find a wire that contains the required type. + #[display( + "Could not find a wire with type {ty} that contains {expected_arguments}", + expected_arguments = match (qubit_args.is_empty(), bit_args.is_empty()) { + (true, true) => "no arguments".to_string(), + (true, false) => format!("pytket bit arguments [{}]", bit_args.iter().join(", ")), + (false, true) => format!("pytket qubit arguments [{}]", qubit_args.iter().join(", ")), + (false, false) => format!("pytket qubit and bit arguments [{}] and [{}]", qubit_args.iter().join(", "), bit_args.iter().join(", ")), + }, + )] + NoMatchingWire { + /// The type that couldn't be found. + ty: String, + /// The qubit registers expected in the wire. + qubit_args: Vec, + /// The bit registers expected in the wire. + bit_args: Vec, + }, + /// The number of pytket registers expected for an operation is not enough. + /// + /// This is usually caused by a mismatch between the input signature and the number of registers in the pytket circuit. + /// + /// The expected number of registers may be different depending on the [`PytketTypeTranslator`][extension::PytketTypeTranslator]s used in the decoder config. + #[display( + "Expected {expected_count} to map types ({expected_types}), but only got {actual_count}", + expected_types = expected_types.iter().join(", "), + )] + NotEnoughPytketRegisters { + /// The types we tried to get wires for. + expected_types: Vec, + /// The number of registers required by the types. + expected_count: RegisterCount, + /// The number of registers we actually got. + actual_count: RegisterCount, + }, +} + +impl PytketDecodeErrorInner { + /// Wrap the error in a [`PytketDecodeError`]. + pub fn wrap(self) -> PytketDecodeError { + PytketDecodeError::from(self) + } } /// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map, diff --git a/tket/src/serialize/pytket/decoder/tracked_elem.rs b/tket/src/serialize/pytket/decoder/tracked_elem.rs index af0a9cb8a..94255b68a 100644 --- a/tket/src/serialize/pytket/decoder/tracked_elem.rs +++ b/tket/src/serialize/pytket/decoder/tracked_elem.rs @@ -4,12 +4,15 @@ reason = "Temporarily unused while we refactor the pytket decoder" )] +use std::hash::Hasher; use std::sync::{Arc, LazyLock}; use hugr::extension::prelude::{bool_t, qb_t}; use hugr::types::Type; use tket_json_rs::register::ElementId as PytketRegister; +use crate::serialize::pytket::RegisterHash; + /// An internal lightweight identifier for a [`TrackedQubit`] in the decoder. #[derive(Clone, Copy, Debug, derive_more::Display, Hash, PartialEq, Eq, PartialOrd, Ord)] pub(super) struct TrackedQubitId(#[display(transparent)] pub usize); @@ -25,10 +28,19 @@ pub(super) struct TrackedBitId(#[display(transparent)] pub usize); /// /// Outdated values no longer correspond to a pytket circuit register, but they /// can still be found in the wires of the hugr being extracted. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Eq, derive_more::Display)] +#[display("{reg}")] pub struct TrackedQubit { + /// The id of this tracked qubit in the [`WireTracker`]. + id: TrackedQubitId, + /// Whether this tracked qubit is outdated, meaning that we have seen the + /// register in a newer wire. outdated: bool, + /// The pytket register for this tracked element. reg: Arc, + /// The hash of the pytket register for this tracked element, used to + /// speed up hashing and equality checks. + reg_hash: RegisterHash, } /// An identifier for a pytket bit register in the data carried by a wire. @@ -38,18 +50,38 @@ pub struct TrackedQubit { /// /// Outdated values no longer correspond to a pytket circuit register, but they /// can still be found in the wires of the hugr being extracted. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Eq, derive_more::Display)] +#[display("{reg}")] pub struct TrackedBit { + /// The id of this tracked bit in the [`WireTracker`]. + id: TrackedBitId, + /// Whether this tracked bit is outdated, meaning that we have seen the + /// register in a newer wire. outdated: bool, + /// The pytket register for this tracked element. reg: Arc, + /// The hash of the pytket register for this tracked element, used to + reg_hash: RegisterHash, } impl TrackedQubit { - /// Return a new tracked qubit. - pub(super) fn new(reg: Arc) -> Self { + /// Returns a new tracked qubit. + pub(super) fn new(id: TrackedQubitId, reg: Arc) -> Self { + let reg_hash = RegisterHash::from(reg.as_ref()); + Self::new_with_hash(id, reg, reg_hash) + } + + /// Returns a new tracked qubit. + pub(super) fn new_with_hash( + id: TrackedQubitId, + reg: Arc, + reg_hash: RegisterHash, + ) -> Self { Self { + id, outdated: false, reg, + reg_hash, } } @@ -69,6 +101,11 @@ impl TrackedQubit { QUBIT_TYPE.clone() } + /// Returns the id of this tracked qubit. + pub(super) fn id(&self) -> TrackedQubitId { + self.id + } + /// Returns `true` if the element has been overwritten by a new value. pub fn is_outdated(&self) -> bool { self.outdated @@ -82,10 +119,22 @@ impl TrackedQubit { impl TrackedBit { /// Returns a new tracked bit. - pub(super) fn new(reg: Arc) -> Self { + pub(super) fn new(id: TrackedBitId, reg: Arc) -> Self { + let reg_hash = RegisterHash::from(reg.as_ref()); + Self::new_with_hash(id, reg, reg_hash) + } + + /// Returns a new tracked bit. + pub(super) fn new_with_hash( + id: TrackedBitId, + reg: Arc, + reg_hash: RegisterHash, + ) -> Self { Self { + id, outdated: false, reg, + reg_hash, } } @@ -105,6 +154,11 @@ impl TrackedBit { BOOL_TYPE.clone() } + /// Returns the id of this tracked bit. + pub(super) fn id(&self) -> TrackedBitId { + self.id + } + /// Returns `true` if the element has been overwritten by a new value. pub fn is_outdated(&self) -> bool { self.outdated @@ -116,6 +170,34 @@ impl TrackedBit { } } +impl PartialEq for TrackedQubit { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.outdated == other.outdated && self.reg_hash == other.reg_hash + } +} + +impl PartialEq for TrackedBit { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.outdated == other.outdated && self.reg_hash == other.reg_hash + } +} + +impl std::hash::Hash for TrackedQubit { + fn hash(&self, state: &mut H) { + self.id.hash(state); + self.outdated.hash(state); + self.reg_hash.hash(state); + } +} + +impl std::hash::Hash for TrackedBit { + fn hash(&self, state: &mut H) { + self.id.hash(state); + self.outdated.hash(state); + self.reg_hash.hash(state); + } +} + #[cfg(test)] mod tests { use super::*; @@ -128,7 +210,7 @@ mod tests { #[rstest] fn tracked_qubit_basic_behaviour() { let reg = Arc::new(ElementId("q".to_string(), vec![0])); - let mut tq = TrackedQubit::new(reg.clone()); + let mut tq = TrackedQubit::new(TrackedQubitId(0), reg.clone()); assert!(!tq.is_outdated()); assert_eq!(tq.pytket_register(), &*reg); @@ -142,7 +224,7 @@ mod tests { #[rstest] fn tracked_bit_basic_behaviour() { let reg = Arc::new(ElementId("c".to_string(), vec![1])); - let mut tb = TrackedBit::new(reg.clone()); + let mut tb = TrackedBit::new(TrackedBitId(0), reg.clone()); assert!(!tb.is_outdated()); assert_eq!(tb.pytket_register(), &*reg); diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index 971a2b157..6fc738dad 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use hugr::builder::{Dataflow as _, FunctionBuilder}; use hugr::ops::Value; -use hugr::std_extensions::arithmetic::float_types::ConstF64; +use hugr::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; use hugr::types::Type; use hugr::{Hugr, Wire}; use indexmap::{IndexMap, IndexSet}; @@ -18,6 +18,7 @@ use itertools::Itertools; use tket_json_rs::register::ElementId as PytketRegister; use crate::extension::rotation::rotation_type; +use crate::serialize::pytket::config::TypeTranslatorSet; use crate::serialize::pytket::decoder::param::parser::{parse_pytket_param, PytketParam}; use crate::serialize::pytket::decoder::{ LoadedParameter, TrackedBit, TrackedBitId, TrackedQubit, TrackedQubitId, @@ -421,20 +422,23 @@ impl WireTracker { /// If the pytket register was already in the tracker, /// marks the previous element as outdated. /// - /// Returns the hash of the register. + /// If the [`RegisterHash`] has already been computed, it can be passed in + /// to avoid recomputing it. pub(super) fn track_qubit( &mut self, qubit_reg: Arc, - ) -> Result { + reg_hash: Option, + ) -> Result<&TrackedQubit, PytketDecodeError> { check_register(&qubit_reg)?; let id = TrackedQubitId(self.qubits.len()); - let hash = RegisterHash::from(qubit_reg.as_ref()); - self.qubits.push(TrackedQubit::new(qubit_reg)); + let hash = reg_hash.unwrap_or_else(|| RegisterHash::from(qubit_reg.as_ref())); + self.qubits + .push(TrackedQubit::new_with_hash(id, qubit_reg, hash)); if let Some(previous_id) = self.latest_qubit_tracker.insert(hash, id) { self.qubits[previous_id.0].mark_outdated(); } - Ok(id) + Ok(self.get_qubit(id)) } /// Track a new pytket bit register. @@ -442,21 +446,23 @@ impl WireTracker { /// If the pytket register was already in the tracker, /// marks the previous element as outdated. /// - /// Returns the hash of the register. + /// If the [`RegisterHash`] has already been computed, it can be passed in + /// to avoid recomputing it. pub(super) fn track_bit( &mut self, bit_reg: Arc, - ) -> Result { + reg_hash: Option, + ) -> Result<&TrackedBit, PytketDecodeError> { check_register(&bit_reg)?; let id = TrackedBitId(self.bits.len()); - let hash = RegisterHash::from(bit_reg.as_ref()); - self.bits.push(TrackedBit::new(bit_reg)); + let hash = reg_hash.unwrap_or_else(|| RegisterHash::from(bit_reg.as_ref())); + self.bits.push(TrackedBit::new_with_hash(id, bit_reg, hash)); if let Some(previous_id) = self.latest_bit_tracker.insert(hash, id) { self.bits[previous_id.0].mark_outdated(); } - Ok(id) + Ok(self.get_bit(id)) } /// Returns the latest tracked qubit for a pytket register. @@ -466,11 +472,11 @@ impl WireTracker { /// The returned element is guaranteed to be up to date (See [`TrackedQubit::is_outdated`]). pub fn tracked_qubit_for_register( &self, - register: impl AsRef, + register: &PytketRegister, ) -> Result<&TrackedQubit, PytketDecodeError> { - let hash = RegisterHash::from(register.as_ref()); + let hash = RegisterHash::from(register); let Some(id) = self.latest_qubit_tracker.get(&hash) else { - return Err(PytketDecodeError::unknown_qubit_reg(register.as_ref())); + return Err(PytketDecodeError::unknown_qubit_reg(register)); }; Ok(self.get_qubit(*id)) } @@ -482,126 +488,120 @@ impl WireTracker { /// The returned element is guaranteed to be up to date (See [`TrackedBit::is_outdated`]). pub fn tracked_bit_for_register( &self, - register: impl AsRef, + register: &PytketRegister, ) -> Result<&TrackedBit, PytketDecodeError> { - let hash = RegisterHash::from(register.as_ref()); + let hash = RegisterHash::from(register); let Some(id) = self.latest_bit_tracker.get(&hash) else { - return Err(PytketDecodeError::unknown_bit_reg(register.as_ref())); + return Err(PytketDecodeError::unknown_bit_reg(register)); }; Ok(self.get_bit(*id)) } - /// Returns a new set of [TrackedWires] for a list of - /// [`circuit_json::Command`][tket_json_rs::circuit_json::Command] inputs. + /// Returns the list of wires that contain the given qubit. + fn qubit_wires(&self, qubit: &TrackedQubit) -> impl Iterator + '_ { + self.qubit_wires[&qubit.id()].iter().copied() + } + + /// Returns the list of wires that contain the given bit. + fn bit_wires(&self, bit: &TrackedBit) -> impl Iterator + '_ { + self.bit_wires[&bit.id()].iter().copied() + } + + /// Returns a new set of [TrackedWires] for a list of [`TrackedQubit`]s, + /// [`TrackedBit`]s, and [`LoadedParameter`]s following the required types. /// - /// Returns an error if a valid set cannot be found. + /// Returns an error if a valid set of wires with the given types cannot be + /// found. /// - /// # Arguments + /// The qubit and bit arguments are only consumed as required by the types. + /// Some registers may be left unused. /// - /// * `hugr` - The hugr to add the wires to. - /// * `args` - The list of pytket element ids to map to wires. - /// * `params` - The list of parameters to load to wires. See [`WireTracker::load_parameter`] for more details. + /// # Arguments /// - // TODO: We'll need to be able to decompose types when we need only _some_ - // of the elements they contain (E.g., extract a value from an array), - // and do it automatically here. - pub(super) fn wire_inputs_for_command<'r>( - &mut self, - hugr: &mut FunctionBuilder<&mut Hugr>, - qubit_args: impl IntoIterator, - bit_args: impl IntoIterator, - params: impl IntoIterator, + /// * `config` - The configuration for the decoder, used to count the qubits and bits required by each type. + /// * `hugr` - The hugr to load the parameters to. + /// * `types` - The types of the arguments we require in the wires. + /// * `qubit_args` - The list of tracked qubits we require in the wires. + /// * `bit_args` - The list of tracked bits we require in the wire. + /// * `params` - The list of parameters to load to wires. See + /// [`WireTracker::load_parameter`] for more details. + pub(super) fn find_typed_wires<'r>( + &self, + type_translators: &TypeTranslatorSet, + types: &[Type], + qubit_args: impl IntoIterator, + bit_args: impl IntoIterator, + params: &[Arc], ) -> Result { // We need to return a set of wires that contain all the arguments. // // We collect this by checking the wires where each element is present, // and collecting them in order. - let mut qubit_args: VecDeque<(TrackedQubitId, &PytketRegister)> = qubit_args - .into_iter() - .map( - |register| match self.latest_qubit_tracker.get(&RegisterHash::from(register)) { - Some(id) => Ok((*id, register)), - None => Err(PytketDecodeError::unknown_qubit_reg(register)), - }, - ) - .collect::>()?; - let mut bit_args: VecDeque<(TrackedBitId, &PytketRegister)> = bit_args - .into_iter() - .map( - |register| match self.latest_bit_tracker.get(&RegisterHash::from(register)) { - Some(id) => Ok((*id, register)), - None => Err(PytketDecodeError::unknown_bit_reg(register)), - }, - ) - .collect::>()?; + let mut qubit_args: VecDeque<&TrackedQubit> = qubit_args.into_iter().collect(); + let mut bit_args: VecDeque<&TrackedBit> = bit_args.into_iter().collect(); - let mut value_wires = Vec::new(); - while !qubit_args.is_empty() || !bit_args.is_empty() { - // Check candidate wires that only contain the elements we need, in the right order. - // - // We may have leftover arguments at the end, which we'll try to - // get from another wire. - let filter_candidate_wire = |w: Wire| { - let wire_data = &self.wires[&w]; - let same_qubits = itertools::equal( - wire_data.qubits.iter(), - qubit_args - .iter() - .map(|(id, _)| id) - .take(wire_data.num_qubits()), - ); - let same_bits = itertools::equal( - wire_data.bits.iter(), - bit_args.iter().map(|(id, _)| id).take(wire_data.num_bits()), - ); - same_qubits && same_bits - }; - - let qubit_candidates = qubit_args - .front() - .into_iter() - .flat_map(|(id, _)| self.qubit_wires[id].iter()); - let bit_candidates = bit_args - .front() - .into_iter() - .flat_map(|(id, _)| self.bit_wires[id].iter()); - let candidate = qubit_candidates - .chain(bit_candidates) - .find(|&&w| filter_candidate_wire(w)); - - // If we found a candidate, add it to the list of wires. - match candidate { - Some(w) => { - // Consume the extracted args, and add the wire to the list. - let wire_data: WireData = self.wires[w].clone(); - qubit_args.drain(..wire_data.num_qubits()); - bit_args.drain(..wire_data.num_bits()); - value_wires.push(wire_data); - } - None => { - // In the future we may be able to decompose some wire containing `arg_ids[0]` internally. - // For now, we just report the error. - return Err(PytketDecodeErrorInner::ArgumentCouldNotBeMapped { + // Map each requested type to a wire. + // + // Ignore parameter inputs. + let param_types = [float64_type(), rotation_type()]; + let value_wires = types + .iter() + .filter(|ty| !param_types.contains(ty)) + .map(|ty| { + let Some(reg_count) = type_translators.type_to_pytket(ty) else { + return Err(PytketDecodeErrorInner::UnexpectedInputType { + unknown_type: ty.to_string(), + all_types: types.iter().map(ToString::to_string).collect(), + } + .wrap()); + }; + + // List candidate wires that contain the qubits and bits we need. + let qubit_candidates = qubit_args + .front() + .into_iter() + .flat_map(|qb| self.qubit_wires(qb)); + let bit_candidates = bit_args + .front() + .into_iter() + .flat_map(|bit| self.bit_wires(bit)); + let mut candidate = qubit_candidates.chain(bit_candidates); + + // Find a wire that contains the correct type.. + let check_wire = |w: &Wire| { + let wire_data = &self.wires[w]; + let qubits = qubit_args.iter().take(reg_count.qubits).map(|q| q.id()); + let bits = bit_args.iter().take(reg_count.bits).map(|bit| bit.id()); + wire_data.ty() == ty + && itertools::equal(wire_data.qubits.iter().copied(), qubits) + && itertools::equal(wire_data.bits.iter().copied(), bits) + }; + let Some(wire) = candidate.find(check_wire) else { + return Err(PytketDecodeErrorInner::NoMatchingWire { + ty: ty.to_string(), qubit_args: qubit_args .iter() - .map(|(_, elem)| elem.to_string()) + .map(|q| q.pytket_register().to_string()) + .collect(), + bit_args: bit_args + .iter() + .map(|bit| bit.pytket_register().to_string()) .collect(), - bit_args: bit_args.iter().map(|(_, elem)| elem.to_string()).collect(), } - .into()); - } - } - } + .wrap()); + }; - // Load the parameters. - let parameter_wires = params - .into_iter() - .map(|param| self.load_parameter(hugr, param)) - .collect_vec(); + // Mark the qubits and bits as used. + qubit_args.drain(..reg_count.qubits); + bit_args.drain(..reg_count.bits); + + Ok(self.wires[&wire].clone()) + }) + .collect::, _>>()?; Ok(TrackedWires { value_wires, - parameter_wires, + parameter_wires: params.to_vec(), }) } @@ -691,11 +691,17 @@ impl WireTracker { ) -> Result<(), PytketDecodeError> { let qubits = qubits .into_iter() - .map(|q| self.track_qubit(q.pytket_register_arc())) + .map(|q| { + self.track_qubit(q.pytket_register_arc(), None) + .map(TrackedQubit::id) + }) .collect::>()?; let bits = bits .into_iter() - .map(|b| self.track_bit(b.pytket_register_arc())) + .map(|b| { + self.track_bit(b.pytket_register_arc(), None) + .map(TrackedBit::id) + }) .collect::>()?; for &q in &qubits { @@ -774,7 +780,7 @@ mod tests { assert_eq!(tracker.known_pytket_bits().count(), 0); // Track an invalid register name. - match tracker.track_qubit(multi_indexed_reg.clone()) { + match tracker.track_qubit(multi_indexed_reg.clone(), None) { Err(PytketDecodeError { inner: PytketDecodeErrorInner::MultiIndexedRegister { register }, .. @@ -806,40 +812,45 @@ mod tests { // Track a new qubit let tracked_q_0 = tracker - .track_qubit(qubit_reg.clone()) - .expect("Should track qubit"); + .track_qubit(qubit_reg.clone(), None) + .expect("Should track qubit") + .clone(); assert_eq!(tracker.known_pytket_qubits().count(), 1); assert_eq!(tracker.known_pytket_bits().count(), 0); let tracked_qubit = tracker .tracked_qubit_for_register(&qubit_reg) - .expect("Should find tracked qubit"); + .expect("Should find tracked qubit") + .clone(); assert!(!tracked_qubit.is_outdated()); - assert_eq!(tracked_qubit, tracker.get_qubit(tracked_q_0)); + assert_eq!(tracked_qubit, tracked_q_0); // Track the same qubit again, it should add a new TrackedQubit and mark the previous one as outdated let tracked_q_1 = tracker - .track_qubit(qubit_reg.clone()) - .expect("Should track qubit again"); + .track_qubit(qubit_reg.clone(), None) + .expect("Should track qubit again") + .clone(); + let tracked_q_0 = tracker.get_qubit(tracked_q_0.id()); assert_eq!(tracker.known_pytket_qubits().count(), 1); // still only one unique register - assert!(tracker.get_qubit(tracked_q_0).is_outdated()); - assert!(!tracker.get_qubit(tracked_q_1).is_outdated()); + assert!(tracked_q_0.is_outdated()); + assert!(!tracked_q_1.is_outdated()); let tracked_qubit = tracker .tracked_qubit_for_register(&qubit_reg) .expect("Should find latest tracked qubit") .clone(); - assert_eq!(&tracked_qubit, tracker.get_qubit(tracked_q_1)); + assert_eq!(tracked_qubit, tracked_q_1); // Track a bit let bit_id = tracker - .track_bit(bit_reg.clone()) - .expect("Should track bit"); + .track_bit(bit_reg.clone(), None) + .expect("Should track bit") + .clone(); assert_eq!(tracker.known_pytket_bits().count(), 1); - assert!(!tracker.get_bit(bit_id).is_outdated()); + assert!(!bit_id.is_outdated()); let tracked_bit = tracker .tracked_bit_for_register(&bit_reg) .expect("Should find tracked bit") .clone(); - assert_eq!(&tracked_bit, tracker.get_bit(bit_id)); + assert_eq!(tracked_bit, bit_id); // Associate the bit and qubit with a wire. tracker From 44ab9927228b32db038298746a6985e1877f78fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Fri, 15 Aug 2025 10:02:07 +0100 Subject: [PATCH 10/11] doc typo --- tket/src/serialize/pytket/decoder/tracked_elem.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tket/src/serialize/pytket/decoder/tracked_elem.rs b/tket/src/serialize/pytket/decoder/tracked_elem.rs index 94255b68a..d8636c8cc 100644 --- a/tket/src/serialize/pytket/decoder/tracked_elem.rs +++ b/tket/src/serialize/pytket/decoder/tracked_elem.rs @@ -61,6 +61,7 @@ pub struct TrackedBit { /// The pytket register for this tracked element. reg: Arc, /// The hash of the pytket register for this tracked element, used to + /// speed up hashing and equality checks. reg_hash: RegisterHash, } From ae40c629fe274bae9e6e125907a0066197248352 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Fri, 15 Aug 2025 15:34:57 +0100 Subject: [PATCH 11/11] Add TODO --- tket/src/serialize/pytket/decoder/tracked_elem.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tket/src/serialize/pytket/decoder/tracked_elem.rs b/tket/src/serialize/pytket/decoder/tracked_elem.rs index d8636c8cc..eb2952286 100644 --- a/tket/src/serialize/pytket/decoder/tracked_elem.rs +++ b/tket/src/serialize/pytket/decoder/tracked_elem.rs @@ -62,6 +62,9 @@ pub struct TrackedBit { reg: Arc, /// The hash of the pytket register for this tracked element, used to /// speed up hashing and equality checks. + // + // TODO: We could put this along with `reg` in a `PytketResource` struct + // that gets used around the crate. reg_hash: RegisterHash, }