Skip to content

Commit dfa5e8f

Browse files
authored
feat!: Add array_from_ptr to ArrayLowering trait (#971)
This is required to implement lowering for extension ops that return arrays. BREAKING CHANGE: `ArrayLowering` trait now requires an additional `array_from_ptr` method.
1 parent 0c64621 commit dfa5e8f

File tree

1 file changed

+88
-2
lines changed

1 file changed

+88
-2
lines changed

tket-qsystem/src/llvm/array_utils.rs

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@
33
// TODO move to hugr-llvm crate
44
// https://github.com/CQCL/tket2/issues/899
55
use anyhow::Result;
6-
use hugr::llvm::extension::collections::array::decompose_array_fat_pointer;
6+
use hugr::extension::prelude::usize_t;
7+
use hugr::llvm::emit::EmitFuncContext;
8+
use hugr::llvm::extension::collections::array::{
9+
build_array_fat_pointer, decompose_array_fat_pointer,
10+
};
711
use hugr::llvm::extension::collections::{array, stack_array};
12+
use hugr::llvm::inkwell::types::{BasicType, BasicTypeEnum};
813
use hugr::llvm::inkwell::values::BasicValueEnum;
914
use hugr::llvm::{inkwell, CodegenExtension};
15+
use hugr::{HugrView, Node};
1016
use inkwell::builder::{Builder, BuilderError};
1117
use inkwell::context::Context;
1218
use inkwell::types::{IntType, PointerType, StructType};
@@ -27,6 +33,15 @@ pub trait ArrayLowering {
2733
builder: &Builder<'c>,
2834
val: BasicValueEnum<'c>,
2935
) -> Result<PointerValue<'c>>;
36+
37+
/// Turns a pointer to the first array element into an array value in the given lowering.
38+
fn array_from_ptr<'c, H: HugrView<Node = Node>>(
39+
&self,
40+
ctx: &mut EmitFuncContext<'c, '_, H>,
41+
ptr: PointerValue<'c>,
42+
elem_type: BasicTypeEnum<'c>,
43+
length: u32,
44+
) -> Result<BasicValueEnum<'c>>;
3045
}
3146

3247
/// Array lowering via the stack as implemented in [stack_array].
@@ -61,6 +76,27 @@ impl<ACG: stack_array::ArrayCodegen + Clone> ArrayLowering for StackArrayLowerin
6176
let (elem_ptr, _) = build_array_alloca(builder, val.into_array_value())?;
6277
Ok(elem_ptr)
6378
}
79+
80+
fn array_from_ptr<'c, H: HugrView<Node = Node>>(
81+
&self,
82+
ctx: &mut EmitFuncContext<'c, '_, H>,
83+
ptr: PointerValue<'c>,
84+
elem_type: BasicTypeEnum<'c>,
85+
length: u32,
86+
) -> Result<BasicValueEnum<'c>> {
87+
let builder = ctx.builder();
88+
let ptr = builder
89+
.build_bit_cast(
90+
ptr,
91+
elem_type
92+
.array_type(length)
93+
.ptr_type(AddressSpace::default()),
94+
"",
95+
)?
96+
.into_pointer_value();
97+
let array = builder.build_load(ptr, "")?.into_array_value();
98+
Ok(array.into())
99+
}
64100
}
65101

66102
/// Array lowering via a heap as implemented in [mod@array].
@@ -92,6 +128,23 @@ impl<ACG: array::ArrayCodegen + Clone> ArrayLowering for HeapArrayLowering<ACG>
92128
let elem_ptr = unsafe { builder.build_in_bounds_gep(array_ptr, &[offset], "")? };
93129
Ok(elem_ptr)
94130
}
131+
132+
fn array_from_ptr<'c, H: HugrView<Node = Node>>(
133+
&self,
134+
ctx: &mut EmitFuncContext<'c, '_, H>,
135+
ptr: PointerValue<'c>,
136+
_elem_type: BasicTypeEnum<'c>,
137+
_length: u32,
138+
) -> Result<BasicValueEnum<'c>> {
139+
let usize_ty = ctx
140+
.typing_session()
141+
.llvm_type(&usize_t())
142+
.expect("Prelude codegen is registered")
143+
.into_int_type();
144+
let offset = usize_ty.const_zero();
145+
let array = build_array_fat_pointer(ctx, ptr, offset)?;
146+
Ok(array.into())
147+
}
95148
}
96149

97150
/// Helper function to allocate an array on the stack.
@@ -231,7 +284,8 @@ pub fn struct_1d_arr_alloc<'a>(
231284
#[cfg(test)]
232285
mod tests {
233286
use super::*;
234-
use hugr::llvm::inkwell::context::Context;
287+
use hugr::llvm::{inkwell::context::Context, test::llvm_ctx};
288+
use rstest::rstest;
235289

236290
/// Test that build_array_alloca properly allocates an array.
237291
#[test]
@@ -377,4 +431,36 @@ mod tests {
377431
// Verify the generated code is valid
378432
assert!(module.verify().is_ok(), "Module verification failed");
379433
}
434+
435+
/// Tests that [ArrayLowering::array_to_ptr] and [ArrayLowering::array_from_ptr] are inverses.
436+
#[rstest]
437+
#[case(DEFAULT_HEAP_ARRAY_LOWERING)]
438+
#[case(DEFAULT_STACK_ARRAY_LOWERING)]
439+
fn test_array_ptr_conversion(#[case] array_lowering: impl ArrayLowering) {
440+
let mut llvm_ctx = llvm_ctx(-1);
441+
llvm_ctx.add_extensions(|cge| cge.add_default_prelude_extensions());
442+
443+
let mod_ctx = llvm_ctx.get_emit_module_context();
444+
let function_type = mod_ctx.iw_context().void_type().fn_type(&[], false);
445+
let function = mod_ctx
446+
.module()
447+
.add_function("test_function", function_type, None);
448+
let mut emit_ctx = EmitFuncContext::new(mod_ctx, function).unwrap();
449+
450+
let elem_ty = emit_ctx.iw_context().i32_type().into();
451+
let size = 2;
452+
453+
let (array_ptr, _) = build_array(emit_ctx.iw_context(), emit_ctx.builder()).unwrap();
454+
let array = array_lowering
455+
.array_from_ptr(&mut emit_ctx, array_ptr, elem_ty, size)
456+
.unwrap();
457+
let new_array_ptr = array_lowering
458+
.array_to_ptr(emit_ctx.builder(), array)
459+
.unwrap();
460+
assert_eq!(array_ptr.get_type(), new_array_ptr.get_type());
461+
let new_array = array_lowering
462+
.array_from_ptr(&mut emit_ctx, new_array_ptr, elem_ty, size)
463+
.unwrap();
464+
assert_eq!(array.get_type(), new_array.get_type());
465+
}
380466
}

0 commit comments

Comments
 (0)