|
| 1 | +//! General tests. |
| 2 | +
|
| 3 | +use std::collections::{HashMap, HashSet}; |
| 4 | + |
| 5 | +use hugr::builder::{Dataflow, DataflowHugr, FunctionBuilder}; |
| 6 | +use hugr::extension::prelude::{bool_t, qb_t}; |
| 7 | + |
| 8 | +use hugr::types::Signature; |
| 9 | +use itertools::Itertools; |
| 10 | +use rstest::{fixture, rstest}; |
| 11 | +use tket::TketOp; |
| 12 | +use tket_json_rs::circuit_json::{self, SerialCircuit}; |
| 13 | +use tket_json_rs::register; |
| 14 | + |
| 15 | +use tket::circuit::Circuit; |
| 16 | +use tket::serialize::pytket::TKETDecode; |
| 17 | + |
| 18 | +use crate::extension::futures::FutureOpBuilder; |
| 19 | +use crate::extension::qsystem::QSystemOp; |
| 20 | +use crate::pytket::{qsystem_decoder_config, qsystem_encoder_config}; |
| 21 | + |
| 22 | +const NATIVE_GATES_JSON: &str = r#"{ |
| 23 | + "phase": "0", |
| 24 | + "bits": [], |
| 25 | + "qubits": [["q", [0]], ["q", [1]]], |
| 26 | + "commands": [ |
| 27 | + {"args": [["q", [0]], ["q", [1]]], "op": {"type": "ZZMax"}}, |
| 28 | + {"args": [["q", [0]], ["q", [1]]], "op": {"params": ["((pi) / (2)) / (pi)"], "type": "ZZPhase"}}, |
| 29 | + {"args":[["q",[0]]],"op":{"params":["(pi) / (3)", "beta"],"type":"PhasedX"}} |
| 30 | + ], |
| 31 | + "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] |
| 32 | + }"#; |
| 33 | + |
| 34 | +/// Check some properties of the serial circuit. |
| 35 | +fn validate_serial_circ(circ: &SerialCircuit) { |
| 36 | + // Check that all commands have valid arguments. |
| 37 | + for command in &circ.commands { |
| 38 | + for arg in &command.args { |
| 39 | + assert!( |
| 40 | + circ.qubits.contains(®ister::Qubit::from(arg.clone())) |
| 41 | + || circ.bits.contains(®ister::Bit::from(arg.clone())), |
| 42 | + "Circuit command {command:?} has an invalid argument '{arg:?}'" |
| 43 | + ); |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + // Check that the implicit permutation is valid. |
| 48 | + let perm: HashMap<register::ElementId, register::ElementId> = circ |
| 49 | + .implicit_permutation |
| 50 | + .iter() |
| 51 | + .map(|p| (p.0.clone().id, p.1.clone().id)) |
| 52 | + .collect(); |
| 53 | + for (key, value) in &perm { |
| 54 | + let valid_qubits = circ.qubits.contains(®ister::Qubit::from(key.clone())) |
| 55 | + && circ.qubits.contains(®ister::Qubit::from(value.clone())); |
| 56 | + let valid_bits = circ.bits.contains(®ister::Bit::from(key.clone())) |
| 57 | + && circ.bits.contains(®ister::Bit::from(value.clone())); |
| 58 | + assert!( |
| 59 | + valid_qubits || valid_bits, |
| 60 | + "Circuit has an invalid permutation '{key:?} -> {value:?}'" |
| 61 | + ); |
| 62 | + } |
| 63 | + assert_eq!( |
| 64 | + perm.len(), |
| 65 | + circ.implicit_permutation.len(), |
| 66 | + "Circuit has duplicate permutations", |
| 67 | + ); |
| 68 | + assert_eq!( |
| 69 | + HashSet::<®ister::ElementId>::from_iter(perm.values()).len(), |
| 70 | + perm.len(), |
| 71 | + "Circuit has duplicate values in permutations" |
| 72 | + ); |
| 73 | +} |
| 74 | + |
| 75 | +fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) { |
| 76 | + assert_eq!(a.name, b.name); |
| 77 | + assert_eq!(a.phase, b.phase); |
| 78 | + assert_eq!(&a.qubits, &b.qubits); |
| 79 | + assert_eq!(a.commands.len(), b.commands.len()); |
| 80 | + |
| 81 | + let bits_a: HashSet<_> = a.bits.iter().collect(); |
| 82 | + let bits_b: HashSet<_> = b.bits.iter().collect(); |
| 83 | + assert_eq!(bits_a, bits_b); |
| 84 | + |
| 85 | + // We ignore the commands order here, as two encodings may swap |
| 86 | + // non-dependant operations. |
| 87 | + // |
| 88 | + // The correct thing here would be to run a deterministic toposort and |
| 89 | + // compare the commands in that order. This is just a quick check that |
| 90 | + // everything is present, ignoring wire dependencies. |
| 91 | + // |
| 92 | + // Another problem is that `Command`s cannot be compared directly; |
| 93 | + // - `command.op.signature`, and `n_qb` are optional and sometimes |
| 94 | + // unset in pytket-generated circs. |
| 95 | + // - qubit arguments names may differ if they have been allocated inside the circuit, |
| 96 | + // as they depend on the traversal argument. Same with classical params. |
| 97 | + // Here we define an ad-hoc subset that can be compared. |
| 98 | + // |
| 99 | + // TODO: Do a proper comparison independent of the toposort ordering, and |
| 100 | + // track register reordering. |
| 101 | + #[derive(PartialEq, Eq, Hash, Debug)] |
| 102 | + struct CommandInfo { |
| 103 | + op_type: tket_json_rs::OpType, |
| 104 | + params: Vec<String>, |
| 105 | + n_args: usize, |
| 106 | + } |
| 107 | + |
| 108 | + impl From<&tket_json_rs::circuit_json::Command> for CommandInfo { |
| 109 | + fn from(command: &tket_json_rs::circuit_json::Command) -> Self { |
| 110 | + let mut info = CommandInfo { |
| 111 | + op_type: command.op.op_type.clone(), |
| 112 | + params: command.op.params.clone().unwrap_or_default(), |
| 113 | + n_args: command.args.len(), |
| 114 | + }; |
| 115 | + |
| 116 | + // Special case for qsystem ops, where ZZMax does not exist. |
| 117 | + if command.op.op_type == tket_json_rs::OpType::ZZMax { |
| 118 | + info.op_type = tket_json_rs::OpType::ZZPhase; |
| 119 | + info.params = vec!["(pi) / (2)".to_string()]; |
| 120 | + } |
| 121 | + |
| 122 | + info |
| 123 | + } |
| 124 | + } |
| 125 | + |
| 126 | + let a_command_count: HashMap<CommandInfo, usize> = a.commands.iter().map_into().counts(); |
| 127 | + let b_command_count: HashMap<CommandInfo, usize> = b.commands.iter().map_into().counts(); |
| 128 | + |
| 129 | + for (a, &count_a) in &a_command_count { |
| 130 | + let count_b = b_command_count.get(a).copied().unwrap_or_default(); |
| 131 | + assert_eq!( |
| 132 | + count_a, count_b, |
| 133 | + "command {a:?} appears {count_a} times in rhs and {count_b} times in lhs" |
| 134 | + ); |
| 135 | + } |
| 136 | + assert_eq!(a_command_count.len(), b_command_count.len()); |
| 137 | +} |
| 138 | + |
| 139 | +/// A simple circuit with some qsystem operations. |
| 140 | +#[fixture] |
| 141 | +fn circ_qsystem_native_gates() -> Circuit { |
| 142 | + let input_t = vec![qb_t()]; |
| 143 | + let output_t = vec![qb_t(), bool_t(), bool_t()]; |
| 144 | + let mut h = |
| 145 | + FunctionBuilder::new("qsystem_native_gates", Signature::new(input_t, output_t)).unwrap(); |
| 146 | + |
| 147 | + let [qb0] = h.input_wires_arr(); |
| 148 | + let [qb1] = h.add_dataflow_op(TketOp::QAlloc, []).unwrap().outputs_arr(); |
| 149 | + |
| 150 | + let [future_bit_0] = h |
| 151 | + .add_dataflow_op(QSystemOp::LazyMeasure, [qb0]) |
| 152 | + .unwrap() |
| 153 | + .outputs_arr(); |
| 154 | + let [qb1, future_bit_1] = h |
| 155 | + .add_dataflow_op(QSystemOp::LazyMeasureReset, [qb1]) |
| 156 | + .unwrap() |
| 157 | + .outputs_arr(); |
| 158 | + |
| 159 | + let [bit_0] = h.add_read(future_bit_0, bool_t()).unwrap(); |
| 160 | + let [bit_1] = h.add_read(future_bit_1, bool_t()).unwrap(); |
| 161 | + |
| 162 | + let hugr = h.finish_hugr_with_outputs([qb1, bit_0, bit_1]).unwrap(); |
| 163 | + |
| 164 | + hugr.into() |
| 165 | +} |
| 166 | + |
| 167 | +#[rstest] |
| 168 | +#[case::native_gates(NATIVE_GATES_JSON, 3, 2)] |
| 169 | +fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num_qubits: usize) { |
| 170 | + let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap(); |
| 171 | + assert_eq!(ser.commands.len(), num_commands); |
| 172 | + |
| 173 | + let circ: Circuit = ser |
| 174 | + .clone() |
| 175 | + .decode_with_config(qsystem_decoder_config()) |
| 176 | + .unwrap(); |
| 177 | + |
| 178 | + assert_eq!(circ.qubit_count(), num_qubits); |
| 179 | + |
| 180 | + let reser: SerialCircuit = |
| 181 | + SerialCircuit::encode_with_config(&circ, qsystem_encoder_config()).unwrap(); |
| 182 | + validate_serial_circ(&reser); |
| 183 | + compare_serial_circs(&ser, &reser); |
| 184 | +} |
| 185 | + |
| 186 | +/// Test the serialisation roundtrip from a tket circuit. |
| 187 | +/// |
| 188 | +/// Note: this is not a pure roundtrip as the encoder may add internal qubits/bits to the circuit. |
| 189 | +#[rstest] |
| 190 | +#[case::native_gates(circ_qsystem_native_gates(), Signature::new_endo(vec![qb_t(), qb_t(), bool_t(), bool_t()]))] |
| 191 | +fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: Signature) { |
| 192 | + let ser: SerialCircuit = |
| 193 | + SerialCircuit::encode_with_config(&circ, qsystem_encoder_config()).unwrap(); |
| 194 | + let deser: Circuit = ser |
| 195 | + .clone() |
| 196 | + .decode_with_config(qsystem_decoder_config()) |
| 197 | + .unwrap(); |
| 198 | + |
| 199 | + let deser_sig = deser.circuit_signature(); |
| 200 | + assert_eq!( |
| 201 | + &decoded_sig.input, &deser_sig.input, |
| 202 | + "Input signature mismatch\n Expected: {}\n Actual: {}", |
| 203 | + &decoded_sig, &deser_sig |
| 204 | + ); |
| 205 | + assert_eq!( |
| 206 | + &decoded_sig.output, &deser_sig.output, |
| 207 | + "Output signature mismatch\n Expected: {}\n Actual: {}", |
| 208 | + &decoded_sig, &deser_sig |
| 209 | + ); |
| 210 | + |
| 211 | + let reser = SerialCircuit::encode_with_config(&deser, qsystem_encoder_config()).unwrap(); |
| 212 | + validate_serial_circ(&reser); |
| 213 | + compare_serial_circs(&ser, &reser); |
| 214 | +} |
0 commit comments