Skip to content

Commit 0e55321

Browse files
authored
fix(qsystem): handle barrier lowering for all array kinds (#1024)
Closes #1023
1 parent a975c1d commit 0e55321

File tree

4 files changed

+176
-75
lines changed

4 files changed

+176
-75
lines changed

tket-qsystem/src/extension/qsystem/barrier.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ mod test {
88
use crate::extension::qsystem::{self, lower_tk2_op};
99
use hugr::builder::{Dataflow, DataflowHugr};
1010
use hugr::extension::prelude::Barrier;
11+
use hugr::std_extensions::collections::borrow_array::borrow_array_type;
12+
use hugr::std_extensions::collections::value_array::value_array_type;
1113
use hugr::{
1214
builder::DFGBuilder,
1315
extension::prelude::{bool_t, option_type, qb_t},
@@ -32,6 +34,8 @@ mod test {
3234
// bare option of qubit is ignored
3335
#[case(vec![qb_t(), option_type(qb_t()).into()], 1, false)]
3436
#[case(vec![array_type(2, bool_t())], 0, false)]
37+
#[case(vec![value_array_type(2, option_type(qb_t()).into())], 2, false)]
38+
#[case(vec![borrow_array_type(2, qb_t())], 2, false)]
3539
// special case, single array of qubits is passed directly to op without unpacking
3640
#[case(vec![array_type(3, qb_t())], 1, true)]
3741
#[case(vec![qb_t(), array_type(2, qb_t()), array_type(2, array_type(2, qb_t()))], 7, false)]
@@ -93,6 +97,7 @@ mod test {
9397
let mut analyzer = QTypeAnalyzer::new();
9498
let tuple_type = hugr::types::Type::new_tuple(type_row);
9599
assert!(!analyzer.is_qubit_container(&tuple_type));
100+
assert_eq!(num_qb, 0);
96101
return;
97102
}
98103
h.single_linked_input(run_barr_func_n.unwrap(), 0)

tket-qsystem/src/extension/qsystem/barrier/barrier_inserter.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ impl BarrierInserter {
6969
target: Target,
7070
) -> Option<Result<(), LowerTk2Error>> {
7171
// Check if this is an array of qubits
72-
let size = is_qubit_array(typ)?;
72+
let size = is_qubit_array::<hugr::std_extensions::collections::array::Array>(typ)?;
73+
74+
// TODO if other array type, convert
7375

7476
// Build and insert the barrier
7577
Some(match build_runtime_barrier_op(size) {

tket-qsystem/src/extension/qsystem/barrier/barrier_ops.rs

Lines changed: 142 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ use hugr::{
99
Extension,
1010
},
1111
ops::{DataflowOpTrait, ExtensionOp, OpName},
12-
std_extensions::collections::array::{array_type, array_type_parametric, ArrayOpBuilder},
12+
std_extensions::collections::{
13+
array::{array_type, op_builder::GenericArrayOpBuilder, Array, ArrayKind},
14+
borrow_array::BorrowArray,
15+
value_array::ValueArray,
16+
},
1317
types::{
1418
type_param::TypeParam, FuncValueType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg,
1519
TypeBound, TypeRV,
@@ -62,6 +66,43 @@ pub struct BarrierOperationFactory {
6266
type_analyzer: QTypeAnalyzer,
6367
}
6468

69+
fn generic_array_unpack_sig<AK: ArrayKind>() -> PolyFuncTypeRV {
70+
PolyFuncTypeRV::new(
71+
vec![
72+
TypeParam::max_nat_type(),
73+
TypeParam::RuntimeType(TypeBound::Linear),
74+
TypeParam::new_list_type(TypeBound::Linear),
75+
],
76+
FuncValueType::new(
77+
AK::ty_parametric(
78+
TypeArg::new_var_use(0, TypeParam::max_nat_type()),
79+
Type::new_var_use(1, TypeBound::Linear),
80+
)
81+
.unwrap(),
82+
TypeRV::new_row_var_use(2, TypeBound::Linear),
83+
),
84+
)
85+
}
86+
87+
/// Helper function to add array operations for any ArrayKind to the extension
88+
fn add_array_ops<AK: ArrayKind>(
89+
ext: &mut Extension,
90+
ext_ref: &std::sync::Weak<Extension>,
91+
unpack_name: OpName,
92+
repack_name: OpName,
93+
) -> Result<(), hugr::extension::ExtensionBuildError> {
94+
let array_unpack_sig = generic_array_unpack_sig::<AK>();
95+
// pack some wires into an array
96+
ext.add_op(
97+
repack_name,
98+
Default::default(),
99+
invert_sig(&array_unpack_sig),
100+
ext_ref,
101+
)?;
102+
// unpack an array into some wires
103+
ext.add_op(unpack_name, Default::default(), array_unpack_sig, ext_ref)?;
104+
Ok(())
105+
}
65106
impl BarrierOperationFactory {
66107
/// Temporary extension name.
67108
pub(super) const TEMP_EXT_NAME: hugr::hugr::IdentList =
@@ -72,6 +113,10 @@ impl BarrierOperationFactory {
72113
pub(super) const WRAPPED_BARRIER: OpName = OpName::new_static("wrapped_barrier");
73114
pub(super) const ARRAY_UNPACK: OpName = OpName::new_static("array_unpack");
74115
pub(super) const ARRAY_REPACK: OpName = OpName::new_static("array_repack");
116+
pub(super) const VARRAY_UNPACK: OpName = OpName::new_static("varray_unpack");
117+
pub(super) const VARRAY_REPACK: OpName = OpName::new_static("varray_repack");
118+
pub(super) const BARRAY_UNPACK: OpName = OpName::new_static("barray_unpack");
119+
pub(super) const BARRAY_REPACK: OpName = OpName::new_static("barray_repack");
75120
pub(super) const TUPLE_UNPACK: OpName = OpName::new_static("tuple_unpack");
76121
pub(super) const TUPLE_REPACK: OpName = OpName::new_static("tuple_repack");
77122

@@ -122,35 +167,17 @@ impl BarrierOperationFactory {
122167
ext_ref,
123168
)
124169
.unwrap();
125-
let array_unpack_sig = PolyFuncTypeRV::new(
126-
vec![
127-
TypeParam::max_nat_type(),
128-
TypeParam::RuntimeType(TypeBound::Linear),
129-
TypeParam::new_list_type(TypeBound::Linear),
130-
],
131-
FuncValueType::new(
132-
array_type_parametric(
133-
TypeArg::new_var_use(0, TypeParam::max_nat_type()),
134-
Type::new_var_use(1, TypeBound::Linear),
135-
)
136-
.unwrap(),
137-
TypeRV::new_row_var_use(2, TypeBound::Linear),
138-
),
139-
);
140-
// pack some wires into an array
141-
ext.add_op(
142-
Self::ARRAY_REPACK,
143-
Default::default(),
144-
invert_sig(&array_unpack_sig),
145-
ext_ref,
146-
)
147-
.unwrap();
148-
// unpack an array into some wires
149-
ext.add_op(
150-
Self::ARRAY_UNPACK,
151-
Default::default(),
152-
array_unpack_sig,
170+
171+
// Add array operations for all ArrayKind types
172+
add_array_ops::<Array>(ext, ext_ref, Self::ARRAY_UNPACK, Self::ARRAY_REPACK)
173+
.unwrap();
174+
add_array_ops::<ValueArray>(ext, ext_ref, Self::VARRAY_UNPACK, Self::VARRAY_REPACK)
175+
.unwrap();
176+
add_array_ops::<BorrowArray>(
177+
ext,
153178
ext_ref,
179+
Self::BARRAY_UNPACK,
180+
Self::BARRAY_REPACK,
154181
)
155182
.unwrap();
156183

@@ -303,46 +330,47 @@ impl BarrierOperationFactory {
303330
)
304331
}
305332

306-
/// Unpack an array into individual wires
307-
pub fn unpack_array(
333+
/// Generic array unpacking using the ArrayKind trait
334+
fn unpack_array<AK: ArrayKind>(
308335
&mut self,
309336
builder: &mut impl Dataflow,
310337
array_wire: Wire,
311338
size: u64,
312339
elem_ty: &Type,
340+
op_name: &OpName,
313341
) -> Result<Vec<Wire>, BuildError> {
314-
// Calculate arguments for the array unpacking
315-
let args = match self.array_args(size, elem_ty) {
342+
let args = match self.array_args::<AK>(size, elem_ty) {
316343
Some(args) => args,
317344
None => return Ok(vec![array_wire]), // Not a qubit-containing array
318345
};
319346

320347
let outputs = self.apply_cached_operation(
321348
builder,
322-
&Self::ARRAY_UNPACK,
349+
op_name,
323350
args.clone(),
324351
&args[..2],
325352
[array_wire],
326353
|slf, func_b| {
327354
let w = func_b.input().out_wire(0);
328-
let elems = func_b.add_array_unpack(elem_ty.clone(), size, w)?;
329-
let unpacked: Vec<_> = elems
355+
let elems = func_b.add_generic_array_unpack::<AK>(elem_ty.clone(), size, w)?;
356+
357+
let result: Vec<_> = elems
330358
.into_iter()
331359
.map(|wire| slf.unpack_container(func_b, elem_ty, wire))
332360
.collect::<Result<Vec<_>, _>>()?
333361
.concat();
334-
Ok(unpacked)
362+
Ok(result)
335363
},
336364
)?;
337365

338366
Ok(outputs.collect())
339367
}
340368

341369
/// Helper function for array arguments
342-
fn array_args(&mut self, size: u64, elem_ty: &Type) -> Option<[TypeArg; 3]> {
370+
fn array_args<AK: ArrayKind>(&mut self, size: u64, elem_ty: &Type) -> Option<[TypeArg; 3]> {
343371
let row = self
344372
.type_analyzer
345-
.unpack_type(&array_type(size, elem_ty.clone()))?;
373+
.unpack_type(&AK::ty(size, elem_ty.clone()))?;
346374
let args = [
347375
size.into(),
348376
elem_ty.clone().into(),
@@ -351,15 +379,16 @@ impl BarrierOperationFactory {
351379
Some(args)
352380
}
353381

354-
/// Repack wires into an array
355-
pub fn repack_array(
382+
/// Generic array repacking using the ArrayKind trait
383+
fn repack_array<AK: ArrayKind>(
356384
&mut self,
357385
builder: &mut impl Dataflow,
358386
elem_wires: impl IntoIterator<Item = Wire>,
359387
size: u64,
360388
elem_ty: &Type,
389+
op_name: &OpName,
361390
) -> Result<Wire, BuildError> {
362-
let args = match self.array_args(size, elem_ty) {
391+
let args = match self.array_args::<AK>(size, elem_ty) {
363392
Some(args) => args,
364393
None => {
365394
return Ok(elem_wires
@@ -373,7 +402,7 @@ impl BarrierOperationFactory {
373402

374403
let mut outputs = self.apply_cached_operation(
375404
builder,
376-
&Self::ARRAY_REPACK,
405+
op_name,
377406
args.clone(),
378407
&args[..2],
379408
elem_wires,
@@ -385,8 +414,8 @@ impl BarrierOperationFactory {
385414
.chunks(inner_row_len)
386415
.map(|chunk| slf.repack_container(func_b, elem_ty, chunk.to_vec()))
387416
.collect();
388-
let array_wire = func_b.add_new_array(elem_ty.clone(), elems?)?;
389417

418+
let array_wire = func_b.add_new_generic_array::<AK>(elem_ty.clone(), elems?)?;
390419
Ok(vec![array_wire])
391420
},
392421
)?;
@@ -528,9 +557,23 @@ impl BarrierOperationFactory {
528557
if is_opt_qb(typ) {
529558
return Ok(vec![self.unpack_option(builder, container_wire)?]);
530559
}
531-
if let Some((n, elem_ty)) = typ.as_extension().and_then(array_args) {
532-
return self.unpack_array(builder, container_wire, n, elem_ty);
560+
macro_rules! handle_array_type {
561+
($array_kind:ty, $unpack_op:expr) => {
562+
if let Some((n, elem_ty)) = typ.as_extension().and_then(array_args::<$array_kind>) {
563+
return self.unpack_array::<$array_kind>(
564+
builder,
565+
container_wire,
566+
n,
567+
elem_ty,
568+
&$unpack_op,
569+
);
570+
}
571+
};
533572
}
573+
574+
handle_array_type!(Array, Self::ARRAY_UNPACK);
575+
handle_array_type!(ValueArray, Self::VARRAY_UNPACK);
576+
handle_array_type!(BorrowArray, Self::BARRAY_UNPACK);
534577
if let Some(row) = typ.as_sum().and_then(SumType::as_tuple) {
535578
let row: hugr::types::TypeRow =
536579
row.clone().try_into().expect("unexpected row variable.");
@@ -556,10 +599,24 @@ impl BarrierOperationFactory {
556599
debug_assert!(unpacked_wires.len() == 1);
557600
return self.repack_option(builder, unpacked_wires[0]);
558601
}
559-
if let Some((n, elem_ty)) = typ.as_extension().and_then(array_args) {
560-
return self.repack_array(builder, unpacked_wires, n, elem_ty);
602+
macro_rules! handle_array_type {
603+
($array_kind:ty, $repack_op:expr) => {
604+
if let Some((n, elem_ty)) = typ.as_extension().and_then(array_args::<$array_kind>) {
605+
return self.repack_array::<$array_kind>(
606+
builder,
607+
unpacked_wires,
608+
n,
609+
elem_ty,
610+
&$repack_op,
611+
);
612+
}
613+
};
561614
}
562615

616+
handle_array_type!(Array, Self::ARRAY_REPACK);
617+
handle_array_type!(ValueArray, Self::VARRAY_REPACK);
618+
handle_array_type!(BorrowArray, Self::BARRAY_REPACK);
619+
563620
if let Some(row) = typ.as_sum().and_then(SumType::as_tuple) {
564621
let row: hugr::types::TypeRow =
565622
row.clone().try_into().expect("unexpected row variable.");
@@ -586,9 +643,9 @@ pub fn build_runtime_barrier_op(array_size: u64) -> Result<Hugr, BuildError> {
586643

587644
#[cfg(test)]
588645
mod tests {
589-
use hugr::{extension::prelude::bool_t, HugrView};
590-
591646
use super::*;
647+
use hugr::{extension::prelude::bool_t, HugrView};
648+
use rstest::rstest;
592649

593650
#[test]
594651
fn test_barrier_op_factory_creation() {
@@ -610,21 +667,48 @@ mod tests {
610667
Ok(())
611668
}
612669

613-
#[test]
614-
fn test_array_unpack_repack() -> Result<(), BuildError> {
670+
#[rstest]
671+
#[case::array(
672+
Array,
673+
BarrierOperationFactory::ARRAY_UNPACK,
674+
BarrierOperationFactory::ARRAY_REPACK
675+
)]
676+
#[case::value_array(
677+
ValueArray,
678+
BarrierOperationFactory::VARRAY_UNPACK,
679+
BarrierOperationFactory::VARRAY_REPACK
680+
)]
681+
#[case::borrow_array(
682+
BorrowArray,
683+
BarrierOperationFactory::BARRAY_UNPACK,
684+
BarrierOperationFactory::BARRAY_REPACK
685+
)]
686+
fn test_array_unpack_repack<AK: ArrayKind>(
687+
#[case] _kind: AK,
688+
#[case] unpack_op: OpName,
689+
#[case] repack_op: OpName,
690+
) -> Result<(), BuildError> {
615691
let mut factory = BarrierOperationFactory::new();
616-
let array_size = 3;
617-
let array_type = array_type(array_size, qb_t());
692+
let array_size = 2;
618693

619-
let mut builder = DFGBuilder::new(Signature::new_endo(array_type))?;
694+
// Create the specific array type
695+
let array_type = AK::ty(array_size, qb_t());
620696

697+
// Build a dataflow graph that unpacks and repacks the array
698+
let mut builder = DFGBuilder::new(Signature::new_endo(array_type))?;
621699
let input = builder.input().out_wire(0);
622-
let unpacked = factory.unpack_array(&mut builder, input, array_size, &qb_t())?;
623-
assert_eq!(unpacked.len(), array_size as usize);
624700

625-
let repacked = factory.repack_array(&mut builder, unpacked, array_size, &qb_t())?;
701+
// Unpack the array
702+
let unpacked =
703+
factory.unpack_array::<AK>(&mut builder, input, array_size, &qb_t(), &unpack_op)?;
704+
705+
// Repack the array
706+
let repacked =
707+
factory.repack_array::<AK>(&mut builder, unpacked, array_size, &qb_t(), &repack_op)?;
708+
626709
let hugr = builder.finish_hugr_with_outputs([repacked])?;
627710
assert!(hugr.validate().is_ok());
711+
628712
Ok(())
629713
}
630714

0 commit comments

Comments
 (0)