Skip to content

Commit 420e3ac

Browse files
committed
feat: Define a wire tracker for the new pytket decoder
1 parent 251fe9a commit 420e3ac

File tree

4 files changed

+1186
-0
lines changed

4 files changed

+1186
-0
lines changed

tket-py/src/circuit.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ create_py_exception!(
8181
"Error type for the conversion between tket and tket1 operations."
8282
);
8383

84+
create_py_exception!(
85+
tket::serialize::pytket::Tk1DecodeError,
86+
PyTK1DecodeError,
87+
"Error type for the conversion between tket1 and tket operations."
88+
);
89+
8490
/// Run the validation checks on a circuit.
8591
#[pyfunction]
8692
pub fn validate_circuit(c: &Bound<PyAny>) -> PyResult<()> {

tket/src/serialize/pytket.rs

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,174 @@ impl<N> Tk1ConvertError<N> {
279279
}
280280
}
281281

282+
/// Error type for conversion between tket2 ops and pytket operations.
283+
#[derive(derive_more::Debug, Display, Error)]
284+
#[non_exhaustive]
285+
pub enum Tk1DecodeError {
286+
/// The pytket circuit uses multi-indexed registers.
287+
//
288+
// This could be supported in the future, if there is a need for it.
289+
#[display("Register {register} in the circuit has multiple indices. Tket2 does not support multi-indexed registers.")]
290+
MultiIndexedRegister {
291+
/// The register name.
292+
register: String,
293+
},
294+
/// Found an unexpected register name.
295+
#[display("Found an unknown qubit register name: {register}.")]
296+
UnknownQubitRegister {
297+
/// The unknown register name.
298+
register: String,
299+
},
300+
/// Found an unexpected bit register name.
301+
#[display("Found an unknown bit register name: {register}.")]
302+
UnknownBitRegister {
303+
/// The unknown register name.
304+
register: String,
305+
},
306+
/// The given signature to use for the HUGR's input wires is not compatible with the number of qubits and bits in the pytket circuit.
307+
///
308+
/// The expected number of qubits and bits may be different depending on the [`PytketTypeTranslator`][extension::PytketTypeTranslator]s used in the decoder config.
309+
#[display(
310+
"The given input types {input_types} to use for the HUGR's input wires are not compatible with the number of qubits and bits in the pytket circuit. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits.",
311+
input_types = input_types.iter().join(", "),
312+
)]
313+
InvalidInputSignature {
314+
/// The given input types.
315+
input_types: Vec<String>,
316+
/// The expected number of qubits in the signature.
317+
expected_qubits: usize,
318+
/// The expected number of bits in the signature.
319+
expected_bits: usize,
320+
/// The number of qubits in the pytket circuit.
321+
circ_qubits: usize,
322+
/// The number of bits in the pytket circuit.
323+
circ_bits: usize,
324+
},
325+
/// The signature to use for the HUGR's output wires is not compatible with the number of qubits and bits in the pytket circuit.
326+
///
327+
/// We don't do any kind of type conversion, so this depends solely on the last operation to update each register.
328+
#[display(
329+
"The expected output types {expected_types} are not compatible with the actual output types {actual_types}, obtained from decoding the pytket circuit.",
330+
expected_types = expected_types.iter().join(", "),
331+
actual_types = actual_types.iter().join(", "),
332+
)]
333+
InvalidOutputSignature {
334+
/// The expected types of the input wires.
335+
expected_types: Vec<String>,
336+
/// The actual types of the input wires.
337+
actual_types: Vec<String>,
338+
},
339+
/// A pytket operation had some input registers that couldn't be mapped to hugr wires.
340+
//
341+
// Some of this errors will be avoided in the future once we are able to decompose complex types automatically.
342+
#[display(
343+
"Could not find a wire with the required qubit arguments [{qubit_args:?}] and bit arguments [{bit_args:?}] for operation {operation}.",
344+
qubit_args = qubit_args.iter().join(", "),
345+
bit_args = bit_args.iter().join(", "),
346+
)]
347+
ArgumentCouldNotBeMapped {
348+
/// The operation type that was being decoded.
349+
operation: String,
350+
/// The qubit arguments that couldn't be mapped.
351+
qubit_args: Vec<String>,
352+
/// The bit arguments that couldn't be mapped.
353+
bit_args: Vec<String>,
354+
},
355+
/// Found an unexpected number of input wires when decoding an operation.
356+
#[display(
357+
"Expected {expected_values} input value wires{expected_types} and {expected_params} input parameters when decoding a {operation}, but found {actual_values} values{actual_types} and {actual_params} parameters.",
358+
expected_types = match expected_types {
359+
None => "".to_string(),
360+
Some(tys) => format!(" with types [{}]", tys.iter().join(", ")),
361+
},
362+
actual_types = match actual_types {
363+
None => "".to_string(),
364+
Some(tys) => format!(" with types [{}]", tys.iter().join(", ")),
365+
},
366+
)]
367+
UnexpectedInputWires {
368+
/// The expected amount of input wires.
369+
expected_values: usize,
370+
/// The expected amount of input parameters.
371+
expected_params: usize,
372+
/// The actual amount of input wires.
373+
actual_values: usize,
374+
/// The actual amount of input parameters.
375+
actual_params: usize,
376+
/// The expected types of the input wires.
377+
expected_types: Option<Vec<String>>,
378+
/// The actual types of the input wires.
379+
actual_types: Option<Vec<String>>,
380+
/// The operation type that was being decoded.
381+
operation: String,
382+
},
383+
/// Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires.
384+
#[display(
385+
"Tried to track the output wires of a node, but the number of tracked elements didn't match the ones in the output wires. Expected {expected_qubits} qubits and {expected_bits} bits, but found {circ_qubits} qubits and {circ_bits} bits in the node outputs."
386+
)]
387+
UnexpectedNodeOutput {
388+
/// The expected number of qubits.
389+
expected_qubits: usize,
390+
/// The expected number of bits.
391+
expected_bits: usize,
392+
/// The number of qubits in HUGR node outputs.
393+
circ_qubits: usize,
394+
/// The number of bits in HUGR node output.
395+
circ_bits: usize,
396+
},
397+
/// Custom user-defined error raised while encoding an operation.
398+
#[display("Error while decoding operation: {msg}")]
399+
CustomError {
400+
/// The custom error message
401+
msg: String,
402+
},
403+
/// Input parameter was defined multiple times.
404+
#[display("Parameter {param} was defined multiple times in the input signature")]
405+
DuplicatedParameter {
406+
/// The parameter name.
407+
param: String,
408+
},
409+
/// Not enough parameter names given for the input signature.
410+
#[display("Tried to initialize a pytket circuit decoder with {num_params_given} given parameter names, but more were required by the input signature.")]
411+
MissingParametersInInput {
412+
/// The number of parameters given.
413+
num_params_given: usize,
414+
},
415+
/// We don't support complex types containing parameters in the input.
416+
//
417+
// This restriction may be relaxed in the future.
418+
#[display("Complex type {ty} contains {num_params} inside it. We only support input parameters in standalone 'float' or 'rotation'-typed wires.")]
419+
UnsupportedParametersInInput {
420+
/// The type that contains the parameters.
421+
ty: String,
422+
/// The number of parameters in the type.
423+
num_params: usize,
424+
},
425+
}
426+
427+
impl Tk1DecodeError {
428+
/// Create a new error with a custom message.
429+
pub fn custom(msg: impl ToString) -> Self {
430+
Self::CustomError {
431+
msg: msg.to_string(),
432+
}
433+
}
434+
435+
/// Create an error for an unknown qubit register.
436+
pub fn unknown_qubit_reg(register: &tket_json_rs::register::ElementId) -> Self {
437+
Self::UnknownQubitRegister {
438+
register: register.to_string(),
439+
}
440+
}
441+
442+
/// Create an error for an unknown bit register.
443+
pub fn unknown_bit_reg(register: &tket_json_rs::register::ElementId) -> Self {
444+
Self::UnknownBitRegister {
445+
register: register.to_string(),
446+
}
447+
}
448+
}
449+
282450
/// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map,
283451
/// avoiding string and vector clones on lookup.
284452
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
//! Pytket qubit and bit elements that we track during decoding.
2+
3+
use std::sync::{Arc, LazyLock};
4+
5+
use hugr::extension::prelude::{bool_t, qb_t};
6+
use hugr::types::Type;
7+
use tket_json_rs::register::ElementId as PytketRegister;
8+
9+
/// An internal lightweight identifier for a [`TrackedQubit`] in the decoder.
10+
#[derive(Clone, Copy, Debug, derive_more::Display, Hash, PartialEq, Eq, PartialOrd, Ord)]
11+
pub(super) struct TrackedQubitId(#[display(transparent)] pub usize);
12+
13+
/// An internal lightweight identifier for a [`TrackedBit`] in the decoder.
14+
#[derive(Clone, Copy, Debug, derive_more::Display, Hash, PartialEq, Eq, PartialOrd, Ord)]
15+
pub(super) struct TrackedBitId(#[display(transparent)] pub usize);
16+
17+
/// An identifier for a pytket qubit register in the data carried by a wire.
18+
///
19+
/// After a pytket circuit assigns a new value to the register, older
20+
/// [`TrackedQubit`]s referring to it become _outdated_.
21+
///
22+
/// Outdated values no longer correspond to a pytket circuit register, but they
23+
/// can still be found in the wires of the hugr being extracted.
24+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
25+
pub struct TrackedQubit {
26+
outdated: bool,
27+
reg: Arc<PytketRegister>,
28+
}
29+
30+
/// An identifier for a pytket bit register in the data carried by a wire.
31+
///
32+
/// After a pytket circuit assigns a new value to the register, older
33+
/// [`TrackedBit`]s referring to it become _outdated_.
34+
///
35+
/// Outdated values no longer correspond to a pytket circuit register, but they
36+
/// can still be found in the wires of the hugr being extracted.
37+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38+
pub struct TrackedBit {
39+
outdated: bool,
40+
reg: Arc<PytketRegister>,
41+
}
42+
43+
impl TrackedQubit {
44+
/// Return a new tracked qubit.
45+
pub(super) fn new(reg: Arc<PytketRegister>) -> Self {
46+
Self {
47+
outdated: false,
48+
reg,
49+
}
50+
}
51+
52+
/// Returns the pytket register for this tracked element.
53+
pub fn pytket_register(&self) -> &PytketRegister {
54+
&self.reg
55+
}
56+
57+
/// Returns the pytket register for this tracked element.
58+
pub fn pytket_register_arc(&self) -> Arc<PytketRegister> {
59+
self.reg.clone()
60+
}
61+
62+
/// Returns the type of the element.
63+
pub fn ty(&self) -> Arc<Type> {
64+
static QUBIT_TYPE: LazyLock<Arc<Type>> = LazyLock::new(|| qb_t().into());
65+
QUBIT_TYPE.clone()
66+
}
67+
68+
/// Returns `true` if the element has been overwritten by a new value.
69+
pub fn is_outdated(&self) -> bool {
70+
self.outdated
71+
}
72+
73+
/// Mark the element as outdated.
74+
pub(super) fn mark_outdated(&mut self) {
75+
self.outdated = true;
76+
}
77+
}
78+
79+
impl TrackedBit {
80+
/// Returns a new tracked bit.
81+
pub(super) fn new(reg: Arc<PytketRegister>) -> Self {
82+
Self {
83+
outdated: false,
84+
reg,
85+
}
86+
}
87+
88+
/// Returns the pytket register for this tracked element.
89+
pub fn pytket_register(&self) -> &PytketRegister {
90+
&self.reg
91+
}
92+
93+
/// Returns the pytket register for this tracked element.
94+
pub fn pytket_register_arc(&self) -> Arc<PytketRegister> {
95+
self.reg.clone()
96+
}
97+
98+
/// Returns the type of the element.
99+
pub fn ty(&self) -> Arc<Type> {
100+
static BOOL_TYPE: LazyLock<Arc<Type>> = LazyLock::new(|| bool_t().into());
101+
BOOL_TYPE.clone()
102+
}
103+
104+
/// Returns `true` if the element has been overwritten by a new value.
105+
pub fn is_outdated(&self) -> bool {
106+
self.outdated
107+
}
108+
109+
/// Mark the element as outdated.
110+
pub(super) fn mark_outdated(&mut self) {
111+
self.outdated = true;
112+
}
113+
}
114+
115+
#[cfg(test)]
116+
mod tests {
117+
use super::*;
118+
use hugr::extension::prelude::{bool_t, qb_t};
119+
use hugr::types::Type;
120+
use rstest::rstest;
121+
use std::sync::Arc;
122+
use tket_json_rs::register::ElementId;
123+
124+
#[rstest]
125+
fn tracked_qubit_basic_behaviour() {
126+
let reg = Arc::new(ElementId("q".to_string(), vec![0]));
127+
let mut tq = TrackedQubit::new(reg.clone());
128+
129+
assert!(!tq.is_outdated());
130+
assert_eq!(tq.pytket_register(), &*reg);
131+
assert_eq!(tq.pytket_register_arc(), reg);
132+
assert_eq!(&*tq.ty(), &Type::from(qb_t()));
133+
134+
tq.mark_outdated();
135+
assert!(tq.is_outdated());
136+
}
137+
138+
#[rstest]
139+
fn tracked_bit_basic_behaviour() {
140+
let reg = Arc::new(ElementId("c".to_string(), vec![1]));
141+
let mut tb = TrackedBit::new(reg.clone());
142+
143+
assert!(!tb.is_outdated());
144+
assert_eq!(tb.pytket_register(), &*reg);
145+
assert_eq!(tb.pytket_register_arc(), reg);
146+
assert_eq!(&*tb.ty(), &Type::from(bool_t()));
147+
148+
tb.mark_outdated();
149+
assert!(tb.is_outdated());
150+
}
151+
}

0 commit comments

Comments
 (0)