Skip to content

feat: Define a wire tracker for the new pytket decoder #1036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 73 additions & 19 deletions tket/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use tket_json_rs::circuit_json::SerialCircuit;
use tket_json_rs::register::{Bit, ElementId, Qubit};

use crate::circuit::Circuit;
use crate::serialize::pytket::extension::RegisterCount;

use self::decoder::Tk1DecoderContext;
pub use crate::passes::pytket::lower_to_pytket;
Expand Down Expand Up @@ -358,19 +359,19 @@ pub enum PytketDecodeErrorInner {
/// The pytket circuit uses multi-indexed registers.
//
// This could be supported in the future, if there is a need for it.
#[display("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers.")]
#[display("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers")]
MultiIndexedRegister {
/// The register name.
register: String,
},
/// Found an unexpected register name.
#[display("Found an unknown qubit register name: {register}.")]
#[display("Found an unknown qubit register name: {register}")]
UnknownQubitRegister {
/// The unknown register name.
register: String,
},
/// Found an unexpected bit register name.
#[display("Found an unknown bit register name: {register}.")]
#[display("Found an unknown bit register name: {register}")]
UnknownBitRegister {
/// The unknown register name.
register: String,
Expand All @@ -379,7 +380,7 @@ pub enum PytketDecodeErrorInner {
///
/// The expected number of qubits and bits may be different depending on the [`PytketTypeTranslator`][extension::PytketTypeTranslator]s used in the decoder config.
#[display(
"The given input types {input_types} to use for the HUGR's input wires are not compatible with the number of qubits and bits in the pytket circuit. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits.",
"The given input types {input_types} to use for the HUGR's input wires are not compatible with the number of qubits and bits in the pytket circuit. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits",
input_types = input_types.iter().join(", "),
)]
InvalidInputSignature {
Expand All @@ -398,7 +399,7 @@ pub enum PytketDecodeErrorInner {
///
/// We don't do any kind of type conversion, so this depends solely on the last operation to update each register.
#[display(
"The expected output types {expected_types} are not compatible with the actual output types {actual_types}, obtained from decoding the pytket circuit.",
"The expected output types {expected_types} are not compatible with the actual output types {actual_types}, obtained from decoding the pytket circuit",
expected_types = expected_types.iter().join(", "),
actual_types = actual_types.iter().join(", "),
)]
Expand All @@ -412,7 +413,7 @@ pub enum PytketDecodeErrorInner {
//
// Some of this errors will be avoided in the future once we are able to decompose complex types automatically.
#[display(
"Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}].",
"Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}]",
qubit_args = qubit_args.iter().join(", "),
bit_args = bit_args.iter().join(", "),
)]
Expand All @@ -424,16 +425,16 @@ pub enum PytketDecodeErrorInner {
},
/// Found an unexpected number of input wires when decoding an operation.
#[display(
"Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters, but found {actual_values} values{actual_types} and {actual_params} parameters.",
expected_types = match expected_types {
None => "".to_string(),
Some(tys) => format!(" with types [{}]", tys.iter().join(", ")),
},
actual_types = match actual_types {
None => "".to_string(),
Some(tys) => format!(" with types [{}]", tys.iter().join(", ")),
},
)]
"Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters, but found {actual_values} values{actual_types} and {actual_params} parameters",
expected_types = match expected_types {
None => "".to_string(),
Some(tys) => format!(" with types [{}]", tys.iter().join(", ")),
},
actual_types = match actual_types {
None => "".to_string(),
Some(tys) => format!(" with types [{}]", tys.iter().join(", ")),
},
)]
UnexpectedInputWires {
/// The expected amount of input wires.
expected_values: usize,
Expand All @@ -448,9 +449,20 @@ pub enum PytketDecodeErrorInner {
/// The actual types of the input wires.
actual_types: Option<Vec<String>>,
},
/// Found an unexpected input type when decoding an operation.
#[display(
"Found an unexpected type {unknown_type} in the input wires, in input signature ({all_types})",
all_types = all_types.iter().join(", "),
)]
UnexpectedInputType {
/// The unknown type.
unknown_type: String,
/// All the input types specified for the operation.
all_types: Vec<String>,
},
/// Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires.
#[display(
"Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits in the node outputs."
"Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits in the node outputs"
)]
UnexpectedNodeOutput {
/// The expected number of qubits.
Expand All @@ -475,21 +487,63 @@ pub enum PytketDecodeErrorInner {
param: String,
},
/// Not enough parameter names given for the input signature.
#[display("Tried to initialize a pytket circuit decoder with {num_params_given} given parameter names, but more were required by the input signature.")]
#[display("Tried to initialize a pytket circuit decoder with {num_params_given} given parameter names, but more were required by the input signature")]
MissingParametersInInput {
/// The number of parameters given.
num_params_given: usize,
},
/// We don't support complex types containing parameters in the input.
//
// This restriction may be relaxed in the future.
#[display("Complex type {ty} contains {num_params} inside it. We only support input parameters in standalone 'float' or 'rotation'-typed wires.")]
#[display("Complex type {ty} contains {num_params} inside it. We only support input parameters in standalone 'float' or 'rotation'-typed wires")]
UnsupportedParametersInInput {
/// The type that contains the parameters.
ty: String,
/// The number of parameters in the type.
num_params: usize,
},
/// We couldn't find a wire that contains the required type.
#[display(
"Could not find a wire with type {ty} that contains {expected_arguments}",
expected_arguments = match (qubit_args.is_empty(), bit_args.is_empty()) {
(true, true) => "no arguments".to_string(),
(true, false) => format!("pytket bit arguments [{}]", bit_args.iter().join(", ")),
(false, true) => format!("pytket qubit arguments [{}]", qubit_args.iter().join(", ")),
(false, false) => format!("pytket qubit and bit arguments [{}] and [{}]", qubit_args.iter().join(", "), bit_args.iter().join(", ")),
},
)]
NoMatchingWire {
/// The type that couldn't be found.
ty: String,
/// The qubit registers expected in the wire.
qubit_args: Vec<String>,
/// The bit registers expected in the wire.
bit_args: Vec<String>,
},
/// The number of pytket registers expected for an operation is not enough.
///
/// This is usually caused by a mismatch between the input signature and the number of registers in the pytket circuit.
///
/// The expected number of registers may be different depending on the [`PytketTypeTranslator`][extension::PytketTypeTranslator]s used in the decoder config.
#[display(
"Expected {expected_count} to map types ({expected_types}), but only got {actual_count}",
expected_types = expected_types.iter().join(", "),
)]
NotEnoughPytketRegisters {
/// The types we tried to get wires for.
expected_types: Vec<String>,
/// The number of registers required by the types.
expected_count: RegisterCount,
/// The number of registers we actually got.
actual_count: RegisterCount,
},
}

impl PytketDecodeErrorInner {
/// Wrap the error in a [`PytketDecodeError`].
pub fn wrap(self) -> PytketDecodeError {
PytketDecodeError::from(self)
}
}

/// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map,
Expand Down
96 changes: 89 additions & 7 deletions tket/src/serialize/pytket/decoder/tracked_elem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
reason = "Temporarily unused while we refactor the pytket decoder"
)]

use std::hash::Hasher;
use std::sync::{Arc, LazyLock};

use hugr::extension::prelude::{bool_t, qb_t};
use hugr::types::Type;
use tket_json_rs::register::ElementId as PytketRegister;

use crate::serialize::pytket::RegisterHash;

/// An internal lightweight identifier for a [`TrackedQubit`] in the decoder.
#[derive(Clone, Copy, Debug, derive_more::Display, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct TrackedQubitId(#[display(transparent)] pub usize);
Expand All @@ -25,10 +28,19 @@ pub(super) struct TrackedBitId(#[display(transparent)] pub usize);
///
/// Outdated values no longer correspond to a pytket circuit register, but they
/// can still be found in the wires of the hugr being extracted.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Eq, derive_more::Display)]
#[display("{reg}")]
pub struct TrackedQubit {
/// The id of this tracked qubit in the [`WireTracker`].
id: TrackedQubitId,
/// Whether this tracked qubit is outdated, meaning that we have seen the
/// register in a newer wire.
outdated: bool,
/// The pytket register for this tracked element.
reg: Arc<PytketRegister>,
/// The hash of the pytket register for this tracked element, used to
/// speed up hashing and equality checks.
reg_hash: RegisterHash,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally these two should be in a struct by themselves - they're not independent

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally intended to put this inside the Arc, but RegisterHash is a single byte so it's cheaper to just have it here.
I guess we could add an intermediary

pub struct PytketResource {
    element: Arc<tket_json_rs::register::ElementId>,
    hash: RegisterHash,
}

I'll add a TODO

}

/// An identifier for a pytket bit register in the data carried by a wire.
Expand All @@ -38,18 +50,38 @@ pub struct TrackedQubit {
///
/// Outdated values no longer correspond to a pytket circuit register, but they
/// can still be found in the wires of the hugr being extracted.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Eq, derive_more::Display)]
#[display("{reg}")]
pub struct TrackedBit {
/// The id of this tracked bit in the [`WireTracker`].
id: TrackedBitId,
/// Whether this tracked bit is outdated, meaning that we have seen the
/// register in a newer wire.
outdated: bool,
/// The pytket register for this tracked element.
reg: Arc<PytketRegister>,
/// The hash of the pytket register for this tracked element, used to
reg_hash: RegisterHash,
}

impl TrackedQubit {
/// Return a new tracked qubit.
pub(super) fn new(reg: Arc<PytketRegister>) -> Self {
/// Returns a new tracked qubit.
pub(super) fn new(id: TrackedQubitId, reg: Arc<PytketRegister>) -> Self {
let reg_hash = RegisterHash::from(reg.as_ref());
Self::new_with_hash(id, reg, reg_hash)
}

/// Returns a new tracked qubit.
pub(super) fn new_with_hash(
id: TrackedQubitId,
reg: Arc<PytketRegister>,
reg_hash: RegisterHash,
) -> Self {
Self {
id,
outdated: false,
reg,
reg_hash,
}
}

Expand All @@ -69,6 +101,11 @@ impl TrackedQubit {
QUBIT_TYPE.clone()
}

/// Returns the id of this tracked qubit.
pub(super) fn id(&self) -> TrackedQubitId {
self.id
}

/// Returns `true` if the element has been overwritten by a new value.
pub fn is_outdated(&self) -> bool {
self.outdated
Expand All @@ -82,10 +119,22 @@ impl TrackedQubit {

impl TrackedBit {
/// Returns a new tracked bit.
pub(super) fn new(reg: Arc<PytketRegister>) -> Self {
pub(super) fn new(id: TrackedBitId, reg: Arc<PytketRegister>) -> Self {
let reg_hash = RegisterHash::from(reg.as_ref());
Self::new_with_hash(id, reg, reg_hash)
}

/// Returns a new tracked bit.
pub(super) fn new_with_hash(
id: TrackedBitId,
reg: Arc<PytketRegister>,
reg_hash: RegisterHash,
) -> Self {
Self {
id,
outdated: false,
reg,
reg_hash,
}
}

Expand All @@ -105,6 +154,11 @@ impl TrackedBit {
BOOL_TYPE.clone()
}

/// Returns the id of this tracked bit.
pub(super) fn id(&self) -> TrackedBitId {
self.id
}

/// Returns `true` if the element has been overwritten by a new value.
pub fn is_outdated(&self) -> bool {
self.outdated
Expand All @@ -116,6 +170,34 @@ impl TrackedBit {
}
}

impl PartialEq for TrackedQubit {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.outdated == other.outdated && self.reg_hash == other.reg_hash
}
}

impl PartialEq for TrackedBit {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.outdated == other.outdated && self.reg_hash == other.reg_hash
}
}

impl std::hash::Hash for TrackedQubit {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.outdated.hash(state);
self.reg_hash.hash(state);
}
}

impl std::hash::Hash for TrackedBit {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.outdated.hash(state);
self.reg_hash.hash(state);
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -128,7 +210,7 @@ mod tests {
#[rstest]
fn tracked_qubit_basic_behaviour() {
let reg = Arc::new(ElementId("q".to_string(), vec![0]));
let mut tq = TrackedQubit::new(reg.clone());
let mut tq = TrackedQubit::new(TrackedQubitId(0), reg.clone());

assert!(!tq.is_outdated());
assert_eq!(tq.pytket_register(), &*reg);
Expand All @@ -142,7 +224,7 @@ mod tests {
#[rstest]
fn tracked_bit_basic_behaviour() {
let reg = Arc::new(ElementId("c".to_string(), vec![1]));
let mut tb = TrackedBit::new(reg.clone());
let mut tb = TrackedBit::new(TrackedBitId(0), reg.clone());

assert!(!tb.is_outdated());
assert_eq!(tb.pytket_register(), &*reg);
Expand Down
Loading
Loading