@@ -9,7 +9,11 @@ use hugr::{
9
9
Extension ,
10
10
} ,
11
11
ops:: { DataflowOpTrait , ExtensionOp , OpName } ,
12
- std_extensions:: collections:: array:: { array_type, array_type_parametric, ArrayOpBuilder } ,
12
+ std_extensions:: collections:: {
13
+ array:: { array_type, op_builder:: GenericArrayOpBuilder , Array , ArrayKind } ,
14
+ borrow_array:: BorrowArray ,
15
+ value_array:: ValueArray ,
16
+ } ,
13
17
types:: {
14
18
type_param:: TypeParam , FuncValueType , PolyFuncTypeRV , Signature , SumType , Type , TypeArg ,
15
19
TypeBound , TypeRV ,
@@ -62,6 +66,43 @@ pub struct BarrierOperationFactory {
62
66
type_analyzer : QTypeAnalyzer ,
63
67
}
64
68
69
+ fn generic_array_unpack_sig < AK : ArrayKind > ( ) -> PolyFuncTypeRV {
70
+ PolyFuncTypeRV :: new (
71
+ vec ! [
72
+ TypeParam :: max_nat_type( ) ,
73
+ TypeParam :: RuntimeType ( TypeBound :: Linear ) ,
74
+ TypeParam :: new_list_type( TypeBound :: Linear ) ,
75
+ ] ,
76
+ FuncValueType :: new (
77
+ AK :: ty_parametric (
78
+ TypeArg :: new_var_use ( 0 , TypeParam :: max_nat_type ( ) ) ,
79
+ Type :: new_var_use ( 1 , TypeBound :: Linear ) ,
80
+ )
81
+ . unwrap ( ) ,
82
+ TypeRV :: new_row_var_use ( 2 , TypeBound :: Linear ) ,
83
+ ) ,
84
+ )
85
+ }
86
+
87
+ /// Helper function to add array operations for any ArrayKind to the extension
88
+ fn add_array_ops < AK : ArrayKind > (
89
+ ext : & mut Extension ,
90
+ ext_ref : & std:: sync:: Weak < Extension > ,
91
+ unpack_name : OpName ,
92
+ repack_name : OpName ,
93
+ ) -> Result < ( ) , hugr:: extension:: ExtensionBuildError > {
94
+ let array_unpack_sig = generic_array_unpack_sig :: < AK > ( ) ;
95
+ // pack some wires into an array
96
+ ext. add_op (
97
+ repack_name,
98
+ Default :: default ( ) ,
99
+ invert_sig ( & array_unpack_sig) ,
100
+ ext_ref,
101
+ ) ?;
102
+ // unpack an array into some wires
103
+ ext. add_op ( unpack_name, Default :: default ( ) , array_unpack_sig, ext_ref) ?;
104
+ Ok ( ( ) )
105
+ }
65
106
impl BarrierOperationFactory {
66
107
/// Temporary extension name.
67
108
pub ( super ) const TEMP_EXT_NAME : hugr:: hugr:: IdentList =
@@ -72,6 +113,10 @@ impl BarrierOperationFactory {
72
113
pub ( super ) const WRAPPED_BARRIER : OpName = OpName :: new_static ( "wrapped_barrier" ) ;
73
114
pub ( super ) const ARRAY_UNPACK : OpName = OpName :: new_static ( "array_unpack" ) ;
74
115
pub ( super ) const ARRAY_REPACK : OpName = OpName :: new_static ( "array_repack" ) ;
116
+ pub ( super ) const VARRAY_UNPACK : OpName = OpName :: new_static ( "varray_unpack" ) ;
117
+ pub ( super ) const VARRAY_REPACK : OpName = OpName :: new_static ( "varray_repack" ) ;
118
+ pub ( super ) const BARRAY_UNPACK : OpName = OpName :: new_static ( "barray_unpack" ) ;
119
+ pub ( super ) const BARRAY_REPACK : OpName = OpName :: new_static ( "barray_repack" ) ;
75
120
pub ( super ) const TUPLE_UNPACK : OpName = OpName :: new_static ( "tuple_unpack" ) ;
76
121
pub ( super ) const TUPLE_REPACK : OpName = OpName :: new_static ( "tuple_repack" ) ;
77
122
@@ -122,35 +167,17 @@ impl BarrierOperationFactory {
122
167
ext_ref,
123
168
)
124
169
. unwrap ( ) ;
125
- let array_unpack_sig = PolyFuncTypeRV :: new (
126
- vec ! [
127
- TypeParam :: max_nat_type( ) ,
128
- TypeParam :: RuntimeType ( TypeBound :: Linear ) ,
129
- TypeParam :: new_list_type( TypeBound :: Linear ) ,
130
- ] ,
131
- FuncValueType :: new (
132
- array_type_parametric (
133
- TypeArg :: new_var_use ( 0 , TypeParam :: max_nat_type ( ) ) ,
134
- Type :: new_var_use ( 1 , TypeBound :: Linear ) ,
135
- )
136
- . unwrap ( ) ,
137
- TypeRV :: new_row_var_use ( 2 , TypeBound :: Linear ) ,
138
- ) ,
139
- ) ;
140
- // pack some wires into an array
141
- ext. add_op (
142
- Self :: ARRAY_REPACK ,
143
- Default :: default ( ) ,
144
- invert_sig ( & array_unpack_sig) ,
145
- ext_ref,
146
- )
147
- . unwrap ( ) ;
148
- // unpack an array into some wires
149
- ext. add_op (
150
- Self :: ARRAY_UNPACK ,
151
- Default :: default ( ) ,
152
- array_unpack_sig,
170
+
171
+ // Add array operations for all ArrayKind types
172
+ add_array_ops :: < Array > ( ext, ext_ref, Self :: ARRAY_UNPACK , Self :: ARRAY_REPACK )
173
+ . unwrap ( ) ;
174
+ add_array_ops :: < ValueArray > ( ext, ext_ref, Self :: VARRAY_UNPACK , Self :: VARRAY_REPACK )
175
+ . unwrap ( ) ;
176
+ add_array_ops :: < BorrowArray > (
177
+ ext,
153
178
ext_ref,
179
+ Self :: BARRAY_UNPACK ,
180
+ Self :: BARRAY_REPACK ,
154
181
)
155
182
. unwrap ( ) ;
156
183
@@ -303,46 +330,47 @@ impl BarrierOperationFactory {
303
330
)
304
331
}
305
332
306
- /// Unpack an array into individual wires
307
- pub fn unpack_array (
333
+ /// Generic array unpacking using the ArrayKind trait
334
+ fn unpack_array < AK : ArrayKind > (
308
335
& mut self ,
309
336
builder : & mut impl Dataflow ,
310
337
array_wire : Wire ,
311
338
size : u64 ,
312
339
elem_ty : & Type ,
340
+ op_name : & OpName ,
313
341
) -> Result < Vec < Wire > , BuildError > {
314
- // Calculate arguments for the array unpacking
315
- let args = match self . array_args ( size, elem_ty) {
342
+ let args = match self . array_args :: < AK > ( size, elem_ty) {
316
343
Some ( args) => args,
317
344
None => return Ok ( vec ! [ array_wire] ) , // Not a qubit-containing array
318
345
} ;
319
346
320
347
let outputs = self . apply_cached_operation (
321
348
builder,
322
- & Self :: ARRAY_UNPACK ,
349
+ op_name ,
323
350
args. clone ( ) ,
324
351
& args[ ..2 ] ,
325
352
[ array_wire] ,
326
353
|slf, func_b| {
327
354
let w = func_b. input ( ) . out_wire ( 0 ) ;
328
- let elems = func_b. add_array_unpack ( elem_ty. clone ( ) , size, w) ?;
329
- let unpacked: Vec < _ > = elems
355
+ let elems = func_b. add_generic_array_unpack :: < AK > ( elem_ty. clone ( ) , size, w) ?;
356
+
357
+ let result: Vec < _ > = elems
330
358
. into_iter ( )
331
359
. map ( |wire| slf. unpack_container ( func_b, elem_ty, wire) )
332
360
. collect :: < Result < Vec < _ > , _ > > ( ) ?
333
361
. concat ( ) ;
334
- Ok ( unpacked )
362
+ Ok ( result )
335
363
} ,
336
364
) ?;
337
365
338
366
Ok ( outputs. collect ( ) )
339
367
}
340
368
341
369
/// Helper function for array arguments
342
- fn array_args ( & mut self , size : u64 , elem_ty : & Type ) -> Option < [ TypeArg ; 3 ] > {
370
+ fn array_args < AK : ArrayKind > ( & mut self , size : u64 , elem_ty : & Type ) -> Option < [ TypeArg ; 3 ] > {
343
371
let row = self
344
372
. type_analyzer
345
- . unpack_type ( & array_type ( size, elem_ty. clone ( ) ) ) ?;
373
+ . unpack_type ( & AK :: ty ( size, elem_ty. clone ( ) ) ) ?;
346
374
let args = [
347
375
size. into ( ) ,
348
376
elem_ty. clone ( ) . into ( ) ,
@@ -351,15 +379,16 @@ impl BarrierOperationFactory {
351
379
Some ( args)
352
380
}
353
381
354
- /// Repack wires into an array
355
- pub fn repack_array (
382
+ /// Generic array repacking using the ArrayKind trait
383
+ fn repack_array < AK : ArrayKind > (
356
384
& mut self ,
357
385
builder : & mut impl Dataflow ,
358
386
elem_wires : impl IntoIterator < Item = Wire > ,
359
387
size : u64 ,
360
388
elem_ty : & Type ,
389
+ op_name : & OpName ,
361
390
) -> Result < Wire , BuildError > {
362
- let args = match self . array_args ( size, elem_ty) {
391
+ let args = match self . array_args :: < AK > ( size, elem_ty) {
363
392
Some ( args) => args,
364
393
None => {
365
394
return Ok ( elem_wires
@@ -373,7 +402,7 @@ impl BarrierOperationFactory {
373
402
374
403
let mut outputs = self . apply_cached_operation (
375
404
builder,
376
- & Self :: ARRAY_REPACK ,
405
+ op_name ,
377
406
args. clone ( ) ,
378
407
& args[ ..2 ] ,
379
408
elem_wires,
@@ -385,8 +414,8 @@ impl BarrierOperationFactory {
385
414
. chunks ( inner_row_len)
386
415
. map ( |chunk| slf. repack_container ( func_b, elem_ty, chunk. to_vec ( ) ) )
387
416
. collect ( ) ;
388
- let array_wire = func_b. add_new_array ( elem_ty. clone ( ) , elems?) ?;
389
417
418
+ let array_wire = func_b. add_new_generic_array :: < AK > ( elem_ty. clone ( ) , elems?) ?;
390
419
Ok ( vec ! [ array_wire] )
391
420
} ,
392
421
) ?;
@@ -528,9 +557,23 @@ impl BarrierOperationFactory {
528
557
if is_opt_qb ( typ) {
529
558
return Ok ( vec ! [ self . unpack_option( builder, container_wire) ?] ) ;
530
559
}
531
- if let Some ( ( n, elem_ty) ) = typ. as_extension ( ) . and_then ( array_args) {
532
- return self . unpack_array ( builder, container_wire, n, elem_ty) ;
560
+ macro_rules! handle_array_type {
561
+ ( $array_kind: ty, $unpack_op: expr) => {
562
+ if let Some ( ( n, elem_ty) ) = typ. as_extension( ) . and_then( array_args:: <$array_kind>) {
563
+ return self . unpack_array:: <$array_kind>(
564
+ builder,
565
+ container_wire,
566
+ n,
567
+ elem_ty,
568
+ & $unpack_op,
569
+ ) ;
570
+ }
571
+ } ;
533
572
}
573
+
574
+ handle_array_type ! ( Array , Self :: ARRAY_UNPACK ) ;
575
+ handle_array_type ! ( ValueArray , Self :: VARRAY_UNPACK ) ;
576
+ handle_array_type ! ( BorrowArray , Self :: BARRAY_UNPACK ) ;
534
577
if let Some ( row) = typ. as_sum ( ) . and_then ( SumType :: as_tuple) {
535
578
let row: hugr:: types:: TypeRow =
536
579
row. clone ( ) . try_into ( ) . expect ( "unexpected row variable." ) ;
@@ -556,10 +599,24 @@ impl BarrierOperationFactory {
556
599
debug_assert ! ( unpacked_wires. len( ) == 1 ) ;
557
600
return self . repack_option ( builder, unpacked_wires[ 0 ] ) ;
558
601
}
559
- if let Some ( ( n, elem_ty) ) = typ. as_extension ( ) . and_then ( array_args) {
560
- return self . repack_array ( builder, unpacked_wires, n, elem_ty) ;
602
+ macro_rules! handle_array_type {
603
+ ( $array_kind: ty, $repack_op: expr) => {
604
+ if let Some ( ( n, elem_ty) ) = typ. as_extension( ) . and_then( array_args:: <$array_kind>) {
605
+ return self . repack_array:: <$array_kind>(
606
+ builder,
607
+ unpacked_wires,
608
+ n,
609
+ elem_ty,
610
+ & $repack_op,
611
+ ) ;
612
+ }
613
+ } ;
561
614
}
562
615
616
+ handle_array_type ! ( Array , Self :: ARRAY_REPACK ) ;
617
+ handle_array_type ! ( ValueArray , Self :: VARRAY_REPACK ) ;
618
+ handle_array_type ! ( BorrowArray , Self :: BARRAY_REPACK ) ;
619
+
563
620
if let Some ( row) = typ. as_sum ( ) . and_then ( SumType :: as_tuple) {
564
621
let row: hugr:: types:: TypeRow =
565
622
row. clone ( ) . try_into ( ) . expect ( "unexpected row variable." ) ;
@@ -586,9 +643,9 @@ pub fn build_runtime_barrier_op(array_size: u64) -> Result<Hugr, BuildError> {
586
643
587
644
#[ cfg( test) ]
588
645
mod tests {
589
- use hugr:: { extension:: prelude:: bool_t, HugrView } ;
590
-
591
646
use super :: * ;
647
+ use hugr:: { extension:: prelude:: bool_t, HugrView } ;
648
+ use rstest:: rstest;
592
649
593
650
#[ test]
594
651
fn test_barrier_op_factory_creation ( ) {
@@ -610,21 +667,48 @@ mod tests {
610
667
Ok ( ( ) )
611
668
}
612
669
613
- #[ test]
614
- fn test_array_unpack_repack ( ) -> Result < ( ) , BuildError > {
670
+ #[ rstest]
671
+ #[ case:: array(
672
+ Array ,
673
+ BarrierOperationFactory :: ARRAY_UNPACK ,
674
+ BarrierOperationFactory :: ARRAY_REPACK
675
+ ) ]
676
+ #[ case:: value_array(
677
+ ValueArray ,
678
+ BarrierOperationFactory :: VARRAY_UNPACK ,
679
+ BarrierOperationFactory :: VARRAY_REPACK
680
+ ) ]
681
+ #[ case:: borrow_array(
682
+ BorrowArray ,
683
+ BarrierOperationFactory :: BARRAY_UNPACK ,
684
+ BarrierOperationFactory :: BARRAY_REPACK
685
+ ) ]
686
+ fn test_array_unpack_repack < AK : ArrayKind > (
687
+ #[ case] _kind : AK ,
688
+ #[ case] unpack_op : OpName ,
689
+ #[ case] repack_op : OpName ,
690
+ ) -> Result < ( ) , BuildError > {
615
691
let mut factory = BarrierOperationFactory :: new ( ) ;
616
- let array_size = 3 ;
617
- let array_type = array_type ( array_size, qb_t ( ) ) ;
692
+ let array_size = 2 ;
618
693
619
- let mut builder = DFGBuilder :: new ( Signature :: new_endo ( array_type) ) ?;
694
+ // Create the specific array type
695
+ let array_type = AK :: ty ( array_size, qb_t ( ) ) ;
620
696
697
+ // Build a dataflow graph that unpacks and repacks the array
698
+ let mut builder = DFGBuilder :: new ( Signature :: new_endo ( array_type) ) ?;
621
699
let input = builder. input ( ) . out_wire ( 0 ) ;
622
- let unpacked = factory. unpack_array ( & mut builder, input, array_size, & qb_t ( ) ) ?;
623
- assert_eq ! ( unpacked. len( ) , array_size as usize ) ;
624
700
625
- let repacked = factory. repack_array ( & mut builder, unpacked, array_size, & qb_t ( ) ) ?;
701
+ // Unpack the array
702
+ let unpacked =
703
+ factory. unpack_array :: < AK > ( & mut builder, input, array_size, & qb_t ( ) , & unpack_op) ?;
704
+
705
+ // Repack the array
706
+ let repacked =
707
+ factory. repack_array :: < AK > ( & mut builder, unpacked, array_size, & qb_t ( ) , & repack_op) ?;
708
+
626
709
let hugr = builder. finish_hugr_with_outputs ( [ repacked] ) ?;
627
710
assert ! ( hugr. validate( ) . is_ok( ) ) ;
711
+
628
712
Ok ( ( ) )
629
713
}
630
714
0 commit comments