Skip to content

feat: Define a wire tracker for the new pytket decoder #1036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tket-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ create_py_exception!(
"Error type for the conversion between tket and tket1 operations."
);

create_py_exception!(
tket::serialize::pytket::Tk1DecodeError,
PyTK1DecodeError,
"Error type for the conversion between tket1 and tket operations."
);

/// Run the validation checks on a circuit.
#[pyfunction]
pub fn validate_circuit(c: &Bound<PyAny>) -> PyResult<()> {
Expand Down
168 changes: 168 additions & 0 deletions tket/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,174 @@ impl<N> Tk1ConvertError<N> {
}
}

/// Error type for conversion between tket2 ops and pytket operations.
#[derive(derive_more::Debug, Display, Error)]
#[non_exhaustive]
pub enum Tk1DecodeError {
/// The pytket circuit uses multi-indexed registers.
//
// This could be supported in the future, if there is a need for it.
#[display("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers.")]
MultiIndexedRegister {
/// The register name.
register: String,
},
/// Found an unexpected register name.
#[display("Found an unknown qubit register name: {register}.")]
UnknownQubitRegister {
/// The unknown register name.
register: String,
},
/// Found an unexpected bit register name.
#[display("Found an unknown bit register name: {register}.")]
UnknownBitRegister {
/// The unknown register name.
register: String,
},
/// The given signature to use for the HUGR's input wires is not compatible with the number of qubits and bits in the pytket circuit.
///
/// The expected number of qubits and bits may be different depending on the [`PytketTypeTranslator`][extension::PytketTypeTranslator]s used in the decoder config.
#[display(
"The given input types {input_types} to use for the HUGR's input wires are not compatible with the number of qubits and bits in the pytket circuit. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits.",
input_types = input_types.iter().join(", "),
)]
InvalidInputSignature {
/// The given input types.
input_types: Vec<String>,
/// The expected number of qubits in the signature.
expected_qubits: usize,
/// The expected number of bits in the signature.
expected_bits: usize,
/// The number of qubits in the pytket circuit.
circ_qubits: usize,
/// The number of bits in the pytket circuit.
circ_bits: usize,
},
/// The signature to use for the HUGR's output wires is not compatible with the number of qubits and bits in the pytket circuit.
///
/// We don't do any kind of type conversion, so this depends solely on the last operation to update each register.
#[display(
"The expected output types {expected_types} are not compatible with the actual output types {actual_types}, obtained from decoding the pytket circuit.",
expected_types = expected_types.iter().join(", "),
actual_types = actual_types.iter().join(", "),
)]
InvalidOutputSignature {
/// The expected types of the input wires.
expected_types: Vec<String>,
/// The actual types of the input wires.
actual_types: Vec<String>,
},
/// A pytket operation had some input registers that couldn't be mapped to hugr wires.
//
// Some of this errors will be avoided in the future once we are able to decompose complex types automatically.
#[display(
"Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}] for operation {operation}.",
qubit_args = qubit_args.iter().join(", "),
bit_args = bit_args.iter().join(", "),
)]
ArgumentCouldNotBeMapped {
/// The operation type that was being decoded.
operation: String,
/// The qubit arguments that couldn't be mapped.
qubit_args: Vec<String>,
/// The bit arguments that couldn't be mapped.
bit_args: Vec<String>,
},
/// Found an unexpected number of input wires when decoding an operation.
#[display(
"Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters when decoding a {operation}, but found {actual_values} values{actual_types} and {actual_params} parameters.",
expected_types = match expected_types {
None => "".to_string(),
Some(tys) => format!(" with types [{}]", tys.iter().join(", ")),
},
actual_types = match actual_types {
None => "".to_string(),
Some(tys) => format!(" with types [{}]", tys.iter().join(", ")),
},
)]
UnexpectedInputWires {
/// The expected amount of input wires.
expected_values: usize,
/// The expected amount of input parameters.
expected_params: usize,
/// The actual amount of input wires.
actual_values: usize,
/// The actual amount of input parameters.
actual_params: usize,
/// The expected types of the input wires.
expected_types: Option<Vec<String>>,
/// The actual types of the input wires.
actual_types: Option<Vec<String>>,
/// The operation type that was being decoded.
operation: String,
},
/// Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires.
#[display(
"Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits in the node outputs."
)]
UnexpectedNodeOutput {
/// The expected number of qubits.
expected_qubits: usize,
/// The expected number of bits.
expected_bits: usize,
/// The number of qubits in HUGR node outputs.
circ_qubits: usize,
/// The number of bits in HUGR node output.
circ_bits: usize,
},
/// Custom user-defined error raised while encoding an operation.
#[display("Error while decoding operation: {msg}")]
CustomError {
/// The custom error message
msg: String,
},
/// Input parameter was defined multiple times.
#[display("Parameter {param} was defined multiple times in the input signature")]
DuplicatedParameter {
/// The parameter name.
param: String,
},
/// Not enough parameter names given for the input signature.
#[display("Tried to initialize a pytket circuit decoder with {num_params_given} given parameter names, but more were required by the input signature.")]
MissingParametersInInput {
/// The number of parameters given.
num_params_given: usize,
},
/// We don't support complex types containing parameters in the input.
//
// This restriction may be relaxed in the future.
#[display("Complex type {ty} contains {num_params} inside it. We only support input parameters in standalone 'float' or 'rotation'-typed wires.")]
UnsupportedParametersInInput {
/// The type that contains the parameters.
ty: String,
/// The number of parameters in the type.
num_params: usize,
},
}

impl Tk1DecodeError {
/// Create a new error with a custom message.
pub fn custom(msg: impl ToString) -> Self {
Self::CustomError {
msg: msg.to_string(),
}
}

/// Create an error for an unknown qubit register.
pub fn unknown_qubit_reg(register: &tket_json_rs::register::ElementId) -> Self {
Self::UnknownQubitRegister {
register: register.to_string(),
}
}

/// Create an error for an unknown bit register.
pub fn unknown_bit_reg(register: &tket_json_rs::register::ElementId) -> Self {
Self::UnknownBitRegister {
register: register.to_string(),
}
}
}

/// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map,
/// avoiding string and vector clones on lookup.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
Expand Down
9 changes: 9 additions & 0 deletions tket/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@

mod op;
mod param;
mod tracked_elem;
mod wires;

#[expect(
unused_imports,
reason = "Temporarily unused while we refactor the pytket decoder"
)]
pub use param::{LoadedParameter, LoadedParameterType};
pub use tracked_elem::{TrackedBit, TrackedQubit};
#[expect(
unused_imports,
reason = "Temporarily unused while we refactor the pytket decoder"
)]
pub use wires::{TrackedWires, WireData, WireTracker};

use std::collections::{HashMap, HashSet};

Expand All @@ -19,6 +27,7 @@ use hugr::ops::{OpType, Value};
use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::types::Signature;
use hugr::{Hugr, Wire};
use tracked_elem::{TrackedBitId, TrackedQubitId};

use indexmap::IndexMap;
use itertools::{EitherOrBoth, Itertools};
Expand Down
4 changes: 0 additions & 4 deletions tket/src/serialize/pytket/decoder/param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ impl LoadedParameter {
}

/// Returns the hugr type for the parameter.
#[expect(
dead_code,
reason = "Temporarily unused while we refactor the pytket decoder"
)]
pub fn wire_type(&self) -> &Type {
static FLOAT_TYPE: LazyLock<Type> = LazyLock::new(float64_type);
static ROTATION_TYPE: LazyLock<Type> = LazyLock::new(rotation_type);
Expand Down
155 changes: 155 additions & 0 deletions tket/src/serialize/pytket/decoder/tracked_elem.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
//! Pytket qubit and bit elements that we track during decoding.
#![allow(
dead_code,
reason = "Temporarily unused while we refactor the pytket decoder"
)]

use std::sync::{Arc, LazyLock};

use hugr::extension::prelude::{bool_t, qb_t};
use hugr::types::Type;
use tket_json_rs::register::ElementId as PytketRegister;

/// An internal lightweight identifier for a [`TrackedQubit`] in the decoder.
#[derive(Clone, Copy, Debug, derive_more::Display, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct TrackedQubitId(#[display(transparent)] pub usize);

/// An internal lightweight identifier for a [`TrackedBit`] in the decoder.
#[derive(Clone, Copy, Debug, derive_more::Display, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct TrackedBitId(#[display(transparent)] pub usize);

/// An identifier for a pytket qubit register in the data carried by a wire.
///
/// After a pytket circuit assigns a new value to the register, older
/// [`TrackedQubit`]s referring to it become _outdated_.
///
/// Outdated values no longer correspond to a pytket circuit register, but they
/// can still be found in the wires of the hugr being extracted.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TrackedQubit {
outdated: bool,
reg: Arc<PytketRegister>,
}

/// An identifier for a pytket bit register in the data carried by a wire.
///
/// After a pytket circuit assigns a new value to the register, older
/// [`TrackedBit`]s referring to it become _outdated_.
///
/// Outdated values no longer correspond to a pytket circuit register, but they
/// can still be found in the wires of the hugr being extracted.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TrackedBit {
outdated: bool,
reg: Arc<PytketRegister>,
}

impl TrackedQubit {
/// Return a new tracked qubit.
pub(super) fn new(reg: Arc<PytketRegister>) -> Self {
Self {
outdated: false,
reg,
}
}

/// Returns the pytket register for this tracked element.
pub fn pytket_register(&self) -> &PytketRegister {
&self.reg
}

/// Returns the pytket register for this tracked element.
pub fn pytket_register_arc(&self) -> Arc<PytketRegister> {
self.reg.clone()
}

/// Returns the type of the element.
pub fn ty(&self) -> Arc<Type> {
static QUBIT_TYPE: LazyLock<Arc<Type>> = LazyLock::new(|| qb_t().into());
QUBIT_TYPE.clone()
}

/// Returns `true` if the element has been overwritten by a new value.
pub fn is_outdated(&self) -> bool {
self.outdated
}

/// Mark the element as outdated.
pub(super) fn mark_outdated(&mut self) {
self.outdated = true;
}
}

impl TrackedBit {
/// Returns a new tracked bit.
pub(super) fn new(reg: Arc<PytketRegister>) -> Self {
Self {
outdated: false,
reg,
}
}

/// Returns the pytket register for this tracked element.
pub fn pytket_register(&self) -> &PytketRegister {
&self.reg
}

/// Returns the pytket register for this tracked element.
pub fn pytket_register_arc(&self) -> Arc<PytketRegister> {
self.reg.clone()
}

/// Returns the type of the element.
pub fn ty(&self) -> Arc<Type> {
static BOOL_TYPE: LazyLock<Arc<Type>> = LazyLock::new(|| bool_t().into());
BOOL_TYPE.clone()
}

/// Returns `true` if the element has been overwritten by a new value.
pub fn is_outdated(&self) -> bool {
self.outdated
}

/// Mark the element as outdated.
pub(super) fn mark_outdated(&mut self) {
self.outdated = true;
}
}

#[cfg(test)]
mod tests {
use super::*;
use hugr::extension::prelude::{bool_t, qb_t};
use hugr::types::Type;
use rstest::rstest;
use std::sync::Arc;
use tket_json_rs::register::ElementId;

#[rstest]
fn tracked_qubit_basic_behaviour() {
let reg = Arc::new(ElementId("q".to_string(), vec![0]));
let mut tq = TrackedQubit::new(reg.clone());

assert!(!tq.is_outdated());
assert_eq!(tq.pytket_register(), &*reg);
assert_eq!(tq.pytket_register_arc(), reg);
assert_eq!(&*tq.ty(), &Type::from(qb_t()));

tq.mark_outdated();
assert!(tq.is_outdated());
}

#[rstest]
fn tracked_bit_basic_behaviour() {
let reg = Arc::new(ElementId("c".to_string(), vec![1]));
let mut tb = TrackedBit::new(reg.clone());

assert!(!tb.is_outdated());
assert_eq!(tb.pytket_register(), &*reg);
assert_eq!(tb.pytket_register_arc(), reg);
assert_eq!(&*tb.ty(), &Type::from(bool_t()));

tb.mark_outdated();
assert!(tb.is_outdated());
}
}
Loading
Loading