Skip to content

Commit 1bb86fa

Browse files
authored
feat: Add llvm codegen for tket2.bool (#950)
Copied over from Guppy, closes #909
1 parent 44c1f91 commit 1bb86fa

16 files changed

+431
-0
lines changed

tket2/src/llvm.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
//! `hugr-llvm` codegen extensions for extensions defined in `tket2`.
2+
pub mod bool;
23
pub mod rotation;

tket2/src/llvm/bool.rs

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
//! `hugr-llvm` codegen extension for `tket2.bool`.
2+
3+
use hugr::llvm::emit::emit_value;
4+
use hugr::llvm::emit::func::EmitFuncContext;
5+
use hugr::llvm::emit::EmitOpArgs;
6+
use hugr::llvm::inkwell;
7+
use hugr::llvm::sum::LLVMSumValue;
8+
use hugr::llvm::types::TypingSession;
9+
use hugr::llvm::CodegenExtension;
10+
use hugr::ops::ExtensionOp;
11+
use hugr::ops::Value;
12+
use hugr::types::SumType;
13+
use hugr::types::TypeName;
14+
use hugr::HugrView;
15+
use hugr::Node;
16+
17+
use crate::extension::bool::{BoolOp, ConstBool, BOOL_EXTENSION_ID};
18+
use anyhow::{anyhow, Result};
19+
use inkwell::types::IntType;
20+
use inkwell::IntPredicate;
21+
22+
const BOOL_TYPE_ID: TypeName = TypeName::new_inline("bool");
23+
24+
fn llvm_bool_type<'c>(ts: &TypingSession<'c, '_>) -> IntType<'c> {
25+
ts.iw_context().bool_type()
26+
}
27+
28+
/// A codegen extension for the `tket2.bool` extension.
29+
#[derive(Clone)]
30+
pub struct BoolCodegenExtension;
31+
32+
impl BoolCodegenExtension {
33+
fn emit_bool_op<'c, H: HugrView<Node = Node>>(
34+
&self,
35+
context: &mut EmitFuncContext<'c, '_, H>,
36+
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
37+
op: BoolOp,
38+
) -> Result<()> {
39+
match op {
40+
BoolOp::read => {
41+
let [inp] = args
42+
.inputs
43+
.try_into()
44+
.map_err(|_| anyhow!("BoolOp::read expects one argument"))?;
45+
let res = inp.into_int_value();
46+
let true_val = emit_value(context, &Value::true_val())?;
47+
let false_val = emit_value(context, &Value::false_val())?;
48+
let res = context
49+
.builder()
50+
.build_select(res, true_val, false_val, "")?;
51+
args.outputs.finish(context.builder(), vec![res])
52+
}
53+
BoolOp::make_opaque => {
54+
let [inp] = args
55+
.inputs
56+
.try_into()
57+
.map_err(|_| anyhow!("BoolOp::make_opaque expects one argument"))?;
58+
let bool_ty = context.llvm_sum_type(SumType::new_unary(2))?;
59+
let bool_val = LLVMSumValue::try_new(inp, bool_ty)?;
60+
let res = bool_val.build_get_tag(context.builder())?;
61+
args.outputs.finish(context.builder(), vec![res.into()])
62+
}
63+
BoolOp::not => {
64+
let [inp] = args
65+
.inputs
66+
.try_into()
67+
.map_err(|_| anyhow!("BoolOp::not expects one argument"))?;
68+
let res = inp.into_int_value();
69+
let res = context.builder().build_not(res, "")?;
70+
args.outputs.finish(context.builder(), vec![res.into()])
71+
}
72+
binary_op => {
73+
let [inp1, inp2] = args
74+
.inputs
75+
.try_into()
76+
.map_err(|_| anyhow!("BoolOp::{:?} expects two arguments", binary_op))?;
77+
let inp1_val = inp1.into_int_value();
78+
let inp2_val = inp2.into_int_value();
79+
let res = match binary_op {
80+
BoolOp::and => context.builder().build_and(inp1_val, inp2_val, "")?,
81+
BoolOp::or => context.builder().build_or(inp1_val, inp2_val, "")?,
82+
BoolOp::xor => context.builder().build_xor(inp1_val, inp2_val, "")?,
83+
BoolOp::eq => context.builder().build_int_compare(
84+
IntPredicate::EQ,
85+
inp1_val,
86+
inp2_val,
87+
"",
88+
)?,
89+
_ => return Err(anyhow!("Unsupported binary bool operation")),
90+
};
91+
args.outputs.finish(context.builder(), vec![res.into()])
92+
}
93+
}
94+
}
95+
}
96+
97+
impl CodegenExtension for BoolCodegenExtension {
98+
fn add_extension<'a, H: hugr::HugrView<Node = hugr::Node> + 'a>(
99+
self,
100+
builder: hugr::llvm::CodegenExtsBuilder<'a, H>,
101+
) -> hugr::llvm::CodegenExtsBuilder<'a, H>
102+
where
103+
Self: 'a,
104+
{
105+
builder
106+
.custom_type((BOOL_EXTENSION_ID, BOOL_TYPE_ID), |ts, _| {
107+
Ok(llvm_bool_type(&ts).into())
108+
})
109+
.custom_const::<ConstBool>(|context, val| {
110+
let bool_ty = llvm_bool_type(&context.typing_session());
111+
Ok(bool_ty.const_int(val.value().into(), false).into())
112+
})
113+
.simple_extension_op(move |context, args, op| self.emit_bool_op(context, args, op))
114+
}
115+
}
116+
117+
#[cfg(test)]
118+
mod test {
119+
use rstest::rstest;
120+
121+
use super::*;
122+
123+
use hugr::extension::simple_op::MakeRegisteredOp;
124+
use hugr::llvm::check_emission;
125+
use hugr::llvm::extension::DefaultPreludeCodegen;
126+
use hugr::llvm::test::{llvm_ctx, single_op_hugr, TestContext};
127+
128+
#[rstest]
129+
#[case::read(1, BoolOp::read)]
130+
#[case::make_opaque(2, BoolOp::make_opaque)]
131+
#[case::not(3, BoolOp::not)]
132+
#[case::and(4, BoolOp::and)]
133+
#[case::or(5, BoolOp::or)]
134+
#[case::xor(6, BoolOp::xor)]
135+
#[case::eq(7, BoolOp::eq)]
136+
fn emit_all_ops(#[case] _id: i32, #[with(_id)] mut llvm_ctx: TestContext, #[case] op: BoolOp) {
137+
let pcg = DefaultPreludeCodegen;
138+
llvm_ctx.add_extensions(move |ceb| {
139+
ceb.add_extension(BoolCodegenExtension)
140+
.add_prelude_extensions(pcg.clone())
141+
.add_default_int_extensions()
142+
});
143+
let ext_op = op.to_extension_op().unwrap().into();
144+
let hugr = single_op_hugr(ext_op);
145+
check_emission!(hugr, llvm_ctx);
146+
}
147+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
source: tket2/src/llvm/bool.rs
3+
expression: mod_str
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define i1 @_hl.main.1(i1 %0) {
9+
alloca_block:
10+
br label %entry_block
11+
12+
entry_block: ; preds = %alloca_block
13+
%1 = select i1 %0, i1 true, i1 false
14+
ret i1 %1
15+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
---
2+
source: tket2/src/llvm/bool.rs
3+
expression: mod_str
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define i1 @_hl.main.1(i1 %0) {
9+
alloca_block:
10+
br label %entry_block
11+
12+
entry_block: ; preds = %alloca_block
13+
ret i1 %0
14+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
source: tket2/src/llvm/bool.rs
3+
expression: mod_str
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define i1 @_hl.main.1(i1 %0) {
9+
alloca_block:
10+
br label %entry_block
11+
12+
entry_block: ; preds = %alloca_block
13+
%1 = xor i1 %0, true
14+
ret i1 %1
15+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
source: tket2/src/llvm/bool.rs
3+
expression: mod_str
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define i1 @_hl.main.1(i1 %0, i1 %1) {
9+
alloca_block:
10+
br label %entry_block
11+
12+
entry_block: ; preds = %alloca_block
13+
%2 = and i1 %0, %1
14+
ret i1 %2
15+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
source: tket2/src/llvm/bool.rs
3+
expression: mod_str
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define i1 @_hl.main.1(i1 %0, i1 %1) {
9+
alloca_block:
10+
br label %entry_block
11+
12+
entry_block: ; preds = %alloca_block
13+
%2 = or i1 %0, %1
14+
ret i1 %2
15+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
source: tket2/src/llvm/bool.rs
3+
expression: mod_str
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define i1 @_hl.main.1(i1 %0, i1 %1) {
9+
alloca_block:
10+
br label %entry_block
11+
12+
entry_block: ; preds = %alloca_block
13+
%2 = xor i1 %0, %1
14+
ret i1 %2
15+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
source: tket2/src/llvm/bool.rs
3+
expression: mod_str
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define i1 @_hl.main.1(i1 %0, i1 %1) {
9+
alloca_block:
10+
br label %entry_block
11+
12+
entry_block: ; preds = %alloca_block
13+
%2 = icmp eq i1 %0, %1
14+
ret i1 %2
15+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
---
2+
source: tket2/src/llvm/bool.rs
3+
expression: mod_str
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define i1 @_hl.main.1(i1 %0) {
9+
alloca_block:
10+
%"0" = alloca i1, align 1
11+
%"2_0" = alloca i1, align 1
12+
%"4_0" = alloca i1, align 1
13+
br label %entry_block
14+
15+
entry_block: ; preds = %alloca_block
16+
store i1 %0, i1* %"2_0", align 1
17+
%"2_01" = load i1, i1* %"2_0", align 1
18+
%1 = select i1 %"2_01", i1 true, i1 false
19+
store i1 %1, i1* %"4_0", align 1
20+
%"4_02" = load i1, i1* %"4_0", align 1
21+
store i1 %"4_02", i1* %"0", align 1
22+
%"03" = load i1, i1* %"0", align 1
23+
ret i1 %"03"
24+
}

0 commit comments

Comments
 (0)