Skip to content

Commit 4bde6d9

Browse files
authored
fix!: Change array result ops signature to return array result (#888)
Needed for CQCL/guppylang#981 BREAKING CHANGE: `ResultOpDef::ArrBool`, `ResultOpDef::ArrInt`, `ResultOpDef::ArrUInt` and `ResultOpDef::ArrF64` signatures now return array results
1 parent 724b8bd commit 4bde6d9

11 files changed

+212
-38
lines changed

tket2-exts/src/tket2_exts/data/tket2/result.json

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,32 @@
4444
"bound": "A"
4545
}
4646
],
47-
"output": []
47+
"output": [
48+
{
49+
"t": "Opaque",
50+
"extension": "collections.array",
51+
"id": "array",
52+
"args": [
53+
{
54+
"tya": "Variable",
55+
"idx": 1,
56+
"cached_decl": {
57+
"tp": "BoundedNat",
58+
"bound": null
59+
}
60+
},
61+
{
62+
"tya": "Type",
63+
"ty": {
64+
"t": "Sum",
65+
"s": "Unit",
66+
"size": 2
67+
}
68+
}
69+
],
70+
"bound": "A"
71+
}
72+
]
4873
}
4974
},
5075
"binary": false
@@ -92,7 +117,34 @@
92117
"bound": "A"
93118
}
94119
],
95-
"output": []
120+
"output": [
121+
{
122+
"t": "Opaque",
123+
"extension": "collections.array",
124+
"id": "array",
125+
"args": [
126+
{
127+
"tya": "Variable",
128+
"idx": 1,
129+
"cached_decl": {
130+
"tp": "BoundedNat",
131+
"bound": null
132+
}
133+
},
134+
{
135+
"tya": "Type",
136+
"ty": {
137+
"t": "Opaque",
138+
"extension": "arithmetic.float.types",
139+
"id": "float64",
140+
"args": [],
141+
"bound": "C"
142+
}
143+
}
144+
],
145+
"bound": "A"
146+
}
147+
]
96148
}
97149
},
98150
"binary": false
@@ -153,7 +205,43 @@
153205
"bound": "A"
154206
}
155207
],
156-
"output": []
208+
"output": [
209+
{
210+
"t": "Opaque",
211+
"extension": "collections.array",
212+
"id": "array",
213+
"args": [
214+
{
215+
"tya": "Variable",
216+
"idx": 1,
217+
"cached_decl": {
218+
"tp": "BoundedNat",
219+
"bound": null
220+
}
221+
},
222+
{
223+
"tya": "Type",
224+
"ty": {
225+
"t": "Opaque",
226+
"extension": "arithmetic.int.types",
227+
"id": "int",
228+
"args": [
229+
{
230+
"tya": "Variable",
231+
"idx": 2,
232+
"cached_decl": {
233+
"tp": "BoundedNat",
234+
"bound": 7
235+
}
236+
}
237+
],
238+
"bound": "C"
239+
}
240+
}
241+
],
242+
"bound": "A"
243+
}
244+
]
157245
}
158246
},
159247
"binary": false
@@ -214,7 +302,43 @@
214302
"bound": "A"
215303
}
216304
],
217-
"output": []
305+
"output": [
306+
{
307+
"t": "Opaque",
308+
"extension": "collections.array",
309+
"id": "array",
310+
"args": [
311+
{
312+
"tya": "Variable",
313+
"idx": 1,
314+
"cached_decl": {
315+
"tp": "BoundedNat",
316+
"bound": null
317+
}
318+
},
319+
{
320+
"tya": "Type",
321+
"ty": {
322+
"t": "Opaque",
323+
"extension": "arithmetic.int.types",
324+
"id": "int",
325+
"args": [
326+
{
327+
"tya": "Variable",
328+
"idx": 2,
329+
"cached_decl": {
330+
"tp": "BoundedNat",
331+
"bound": 7
332+
}
333+
}
334+
],
335+
"bound": "C"
336+
}
337+
}
338+
],
339+
"bound": "A"
340+
}
341+
]
218342
}
219343
},
220344
"binary": false

tket2-hseries/src/extension/result.rs

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,20 @@ impl ResultOpDef {
161161
}
162162

163163
fn result_signature(&self) -> SignatureFunc {
164-
PolyFuncType::new(
165-
[vec![TypeParam::String], self.type_params()].concat(),
166-
Signature::new(self.arg_type(), type_row![]),
164+
let sig = if self.is_array_result_op() {
165+
// Do not consume input arrays to allow them to be not copyable.
166+
Signature::new(self.arg_type(), self.arg_type())
167+
} else {
168+
Signature::new(self.arg_type(), type_row![])
169+
};
170+
PolyFuncType::new([vec![TypeParam::String], self.type_params()].concat(), sig).into()
171+
}
172+
173+
fn is_array_result_op(&self) -> bool {
174+
matches!(
175+
self,
176+
Self::ArrBool | Self::ArrF64 | Self::ArrInt | Self::ArrUInt
167177
)
168-
.into()
169178
}
170179
}
171180

@@ -390,19 +399,32 @@ impl TryFrom<&OpType> for ResultOpDef {
390399
pub trait ResultOpBuilder: Dataflow {
391400
/// Add a "tket2.result" op.
392401
fn add_result(&mut self, result_wire: Wire, op: ResultOp) -> Result<(), BuildError> {
402+
debug_assert!(!op.result_op.is_array_result_op());
393403
let handle = self.add_dataflow_op(op, [result_wire])?;
394404

395405
debug_assert_eq!(handle.outputs().len(), 0);
396406
Ok(())
397407
}
408+
409+
/// Add a "tket2.result" op with array result type.
410+
fn add_array_result(
411+
&mut self,
412+
result_wire: Wire,
413+
op: ResultOp,
414+
) -> Result<[Wire; 1], BuildError> {
415+
debug_assert!(op.result_op.is_array_result_op());
416+
Ok(self
417+
.add_dataflow_op(op.clone(), [result_wire])?
418+
.outputs_arr())
419+
}
398420
}
399421

400422
impl<D: Dataflow> ResultOpBuilder for D {}
401423

402424
#[cfg(test)]
403425
pub(crate) mod test {
404426
use cool_asserts::assert_matches;
405-
use hugr::types::Signature;
427+
use hugr::types::{NoRV, Signature, TypeBase};
406428
use hugr::HugrView;
407429
use hugr::{
408430
builder::{Dataflow, DataflowHugr, FunctionBuilder},
@@ -436,17 +458,15 @@ pub(crate) mod test {
436458
INT_TYPES[5].clone(),
437459
INT_TYPES[6].clone(),
438460
];
439-
let in_row = [
440-
in_row.clone(),
441-
in_row
442-
.into_iter()
443-
.map(|t| array_type(ARR_SIZE, t))
444-
.collect(),
445-
]
446-
.concat();
461+
let arrs: Vec<TypeBase<NoRV>> = in_row
462+
.clone()
463+
.into_iter()
464+
.map(|t| array_type(ARR_SIZE, t))
465+
.collect();
466+
let in_row = [in_row.clone(), arrs.clone()].concat();
447467
let hugr = {
448468
let mut func_builder =
449-
FunctionBuilder::new("circuit", Signature::new(in_row, type_row![])).unwrap();
469+
FunctionBuilder::new("circuit", Signature::new(in_row, arrs)).unwrap();
450470
let ops = [
451471
ResultOp::new_bool("b"),
452472
ResultOp::new_f64("f"),
@@ -474,13 +494,15 @@ pub(crate) mod test {
474494
for (w, op) in [b, f, i, u].iter().zip(ops.iter()) {
475495
func_builder.add_result(*w, op.clone()).unwrap();
476496
}
497+
let mut outputs = Vec::new();
477498
for (w, op) in [a_b, a_f, a_i, a_u].iter().zip(ops.iter()) {
478-
func_builder
479-
.add_result(*w, op.clone().array_op(ARR_SIZE))
499+
let [out_w] = func_builder
500+
.add_array_result(*w, op.clone().array_op(ARR_SIZE))
480501
.unwrap();
502+
outputs.push(out_w);
481503
}
482504

483-
func_builder.finish_hugr_with_outputs([]).unwrap()
505+
func_builder.finish_hugr_with_outputs(outputs).unwrap()
484506
};
485507
assert_matches!(hugr.validate(), Ok(_));
486508
}

tket2-hseries/src/llvm/result.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ impl<'c, H: HugrView<Node = Node>> ResultEmitter<'c, '_, '_, H> {
154154
tag_ptr: BasicValueEnum,
155155
tag_len: IntValue,
156156
) -> Result<()> {
157-
// TODO update to return array after https://github.com/CQCL/tket2/pull/888
158157
let ResultArgs::Array(_, length) = op.args else {
159158
bail!("Expected array argument")
160159
};
@@ -241,6 +240,8 @@ impl<'c, H: HugrView<Node = Node>> ResultEmitter<'c, '_, '_, H> {
241240
.try_into()
242241
.map_err(|_| anyhow!("result_arr_bool expects one input"))?;
243242
self.build_print_array_call(val, op, &ElemType::Bool, tag_ptr, tag_len)?;
243+
// Array results need to output the input value again.
244+
return args.outputs.finish(self.builder(), [val]);
244245
}
245246
ResultOpDef::ArrInt => {
246247
let (tag_ptr, tag_len) = self.generate_global_tag(&args, "INTARR:").unwrap();
@@ -249,6 +250,7 @@ impl<'c, H: HugrView<Node = Node>> ResultEmitter<'c, '_, '_, H> {
249250
.try_into()
250251
.map_err(|_| anyhow!("result_arr_int expects one input"))?;
251252
self.build_print_array_call(val, op, &ElemType::Int, tag_ptr, tag_len)?;
253+
return args.outputs.finish(self.builder(), [val]);
252254
}
253255
ResultOpDef::ArrUInt => {
254256
let (tag_ptr, tag_len) = self.generate_global_tag(&args, "INTARR:").unwrap();
@@ -257,6 +259,7 @@ impl<'c, H: HugrView<Node = Node>> ResultEmitter<'c, '_, '_, H> {
257259
.try_into()
258260
.map_err(|_| anyhow!("result_arr_uint expects one input"))?;
259261
self.build_print_array_call(val, op, &ElemType::Uint, tag_ptr, tag_len)?;
262+
return args.outputs.finish(self.builder(), [val]);
260263
}
261264
ResultOpDef::ArrF64 => {
262265
let (tag_ptr, tag_len) = self.generate_global_tag(&args, "FLOATARR:").unwrap();
@@ -265,6 +268,7 @@ impl<'c, H: HugrView<Node = Node>> ResultEmitter<'c, '_, '_, H> {
265268
.try_into()
266269
.map_err(|_| anyhow!("result_arr_float expects one input"))?;
267270
self.build_print_array_call(val, op, &ElemType::Float, tag_ptr, tag_len)?;
271+
return args.outputs.finish(self.builder(), [val]);
268272
}
269273
}
270274
args.outputs.finish(self.builder(), [])

tket2-hseries/src/llvm/snapshots/tket2_hseries__llvm__result__test__emit_result_codegen@llvm14_5.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ source_filename = "test_context"
77

88
@res_test_arr_b.866EEC87.0 = private constant [27 x i8] c"\1AUSER:BOOLARR:test_arr_bool"
99

10-
define void @_hl.main.1([10 x i1] %0) {
10+
define [10 x i1] @_hl.main.1([10 x i1] %0) {
1111
alloca_block:
1212
br label %entry_block
1313

@@ -31,7 +31,7 @@ entry_block: ; preds = %alloca_block
3131
store i1* %3, i1** %mask_ptr, align 8
3232
%5 = load <{ i32, i32, i1*, i1* }>, <{ i32, i32, i1*, i1* }>* %out_arr_alloca, align 1
3333
call void @print_bool_arr(i8* getelementptr inbounds ([27 x i8], [27 x i8]* @res_test_arr_b.866EEC87.0, i32 0, i32 0), i64 %tag_len2, <{ i32, i32, i1*, i1* }>* %out_arr_alloca)
34-
ret void
34+
ret [10 x i1] %0
3535
}
3636

3737
declare void @print_bool_arr(i8*, i64, <{ i32, i32, i1*, i1* }>*)

tket2-hseries/src/llvm/snapshots/tket2_hseries__llvm__result__test__emit_result_codegen@llvm14_6.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ source_filename = "test_context"
77

88
@res_test_arr_i.DFD30452.0 = private constant [25 x i8] c"\18USER:INTARR:test_arr_int"
99

10-
define void @_hl.main.1([10 x i64] %0) {
10+
define [10 x i64] @_hl.main.1([10 x i64] %0) {
1111
alloca_block:
1212
br label %entry_block
1313

@@ -31,7 +31,7 @@ entry_block: ; preds = %alloca_block
3131
store i1* %3, i1** %mask_ptr, align 8
3232
%5 = load <{ i32, i32, i64*, i1* }>, <{ i32, i32, i64*, i1* }>* %out_arr_alloca, align 1
3333
call void @print_int_arr(i8* getelementptr inbounds ([25 x i8], [25 x i8]* @res_test_arr_i.DFD30452.0, i32 0, i32 0), i64 %tag_len2, <{ i32, i32, i64*, i1* }>* %out_arr_alloca)
34-
ret void
34+
ret [10 x i64] %0
3535
}
3636

3737
declare void @print_int_arr(i8*, i64, <{ i32, i32, i64*, i1* }>*)

tket2-hseries/src/llvm/snapshots/tket2_hseries__llvm__result__test__emit_result_codegen@llvm14_7.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ source_filename = "test_context"
77

88
@res_test_arr_u.3D1C515C.0 = private constant [26 x i8] c"\19USER:INTARR:test_arr_uint"
99

10-
define void @_hl.main.1([10 x i64] %0) {
10+
define [10 x i64] @_hl.main.1([10 x i64] %0) {
1111
alloca_block:
1212
br label %entry_block
1313

@@ -31,7 +31,7 @@ entry_block: ; preds = %alloca_block
3131
store i1* %3, i1** %mask_ptr, align 8
3232
%5 = load <{ i32, i32, i64*, i1* }>, <{ i32, i32, i64*, i1* }>* %out_arr_alloca, align 1
3333
call void @print_uint_arr(i8* getelementptr inbounds ([26 x i8], [26 x i8]* @res_test_arr_u.3D1C515C.0, i32 0, i32 0), i64 %tag_len2, <{ i32, i32, i64*, i1* }>* %out_arr_alloca)
34-
ret void
34+
ret [10 x i64] %0
3535
}
3636

3737
declare void @print_uint_arr(i8*, i64, <{ i32, i32, i64*, i1* }>*)

tket2-hseries/src/llvm/snapshots/tket2_hseries__llvm__result__test__emit_result_codegen@llvm14_8.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ source_filename = "test_context"
77

88
@res_test_arr_f.038B27BE.0 = private constant [27 x i8] c"\1AUSER:FLOATARR:test_arr_f64"
99

10-
define void @_hl.main.1([10 x double] %0) {
10+
define [10 x double] @_hl.main.1([10 x double] %0) {
1111
alloca_block:
1212
br label %entry_block
1313

@@ -31,7 +31,7 @@ entry_block: ; preds = %alloca_block
3131
store i1* %3, i1** %mask_ptr, align 8
3232
%5 = load <{ i32, i32, double*, i1* }>, <{ i32, i32, double*, i1* }>* %out_arr_alloca, align 1
3333
call void @print_float_arr(i8* getelementptr inbounds ([27 x i8], [27 x i8]* @res_test_arr_f.038B27BE.0, i32 0, i32 0), i64 %tag_len2, <{ i32, i32, double*, i1* }>* %out_arr_alloca)
34-
ret void
34+
ret [10 x double] %0
3535
}
3636

3737
declare void @print_float_arr(i8*, i64, <{ i32, i32, double*, i1* }>*)

0 commit comments

Comments
 (0)