Skip to content

Commit 0dd4ebd

Browse files
authored
fix: run QystemPass with module as entrypoint (#945)
Detected in selene: type lowering was not working correctly with function entrypoint Moves the backwards compatibility code to entrypoint computation
1 parent 7a4dc21 commit 0dd4ebd

File tree

1 file changed

+37
-19
lines changed

1 file changed

+37
-19
lines changed

tket2-hseries/src/lib.rs

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use hugr::{
1010
ComposablePass as _, LinearizeArrayPass, MonomorphizePass, RemoveDeadFuncsError,
1111
RemoveDeadFuncsPass,
1212
},
13-
hugr::HugrError,
13+
hugr::{hugrmut::HugrMut, HugrError},
1414
Hugr, HugrView, Node,
1515
};
1616
use replace_bools::{ReplaceBoolPass, ReplaceBoolPassError};
@@ -82,22 +82,28 @@ pub enum QSystemPassError<N = Node> {
8282
impl QSystemPass {
8383
/// Run `QSystemPass` on the given [Hugr]. `registry` is used for
8484
/// validation, if enabled.
85+
/// Expects the HUGR to have a function entrypoint.
8586
pub fn run(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> {
87+
let entrypoint = if hugr.entrypoint_optype().is_module() {
88+
// backwards compatibility: if the entrypoint is a module, we look for
89+
// a function named "main" in the module and use that as the entrypoint.
90+
hugr.children(hugr.entrypoint())
91+
.find(|&n| {
92+
hugr.get_optype(n)
93+
.as_func_defn()
94+
.is_some_and(|fd| fd.func_name() == "main")
95+
})
96+
.ok_or(QSystemPassError::NoMain)?
97+
} else {
98+
hugr.entrypoint()
99+
};
100+
101+
// passes that run on whole module
102+
hugr.set_entrypoint(hugr.module_root());
86103
if self.monomorphize {
87104
self.monomorphization().run(hugr).unwrap();
88105

89-
let mut rdfp = RemoveDeadFuncsPass::default();
90-
if hugr.entrypoint_optype().is_module() {
91-
let main_node = hugr
92-
.children(hugr.entrypoint())
93-
.find(|&n| {
94-
hugr.get_optype(n)
95-
.as_func_defn()
96-
.is_some_and(|fd| fd.func_name() == "main")
97-
})
98-
.ok_or(QSystemPassError::NoMain)?;
99-
rdfp = rdfp.with_module_entry_points([main_node]);
100-
}
106+
let rdfp = RemoveDeadFuncsPass::default().with_module_entry_points([entrypoint]);
101107
rdfp.run(hugr)?
102108
}
103109

@@ -112,6 +118,8 @@ impl QSystemPass {
112118
if self.force_order {
113119
self.force_order(hugr)?;
114120
}
121+
// restore the entrypoint
122+
hugr.set_entrypoint(entrypoint);
115123
Ok(())
116124
}
117125

@@ -220,6 +228,7 @@ mod test {
220228
use hugr::{
221229
builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder},
222230
extension::prelude::qb_t,
231+
hugr::hugrmut::HugrMut,
223232
ops::handle::NodeHandle,
224233
std_extensions::arithmetic::float_types::ConstF64,
225234
type_row,
@@ -229,23 +238,26 @@ mod test {
229238

230239
use itertools::Itertools as _;
231240
use petgraph::visit::{Topo, Walker as _};
241+
use rstest::rstest;
232242
use tket2::extension::bool::bool_type;
233243

234244
use crate::{
235245
extension::{futures::FutureOpDef, qsystem::QSystemOp},
236246
QSystemPass,
237247
};
238248

239-
#[test]
240-
fn qsystem_pass() {
249+
#[rstest]
250+
#[case(false)]
251+
#[case(true)]
252+
fn qsystem_pass(#[case] set_entrypoint: bool) {
241253
let mut mb = hugr::builder::ModuleBuilder::new();
242254
let func = mb
243255
.define_function("func", Signature::new_endo(type_row![]))
244256
.unwrap()
245257
.finish_with_outputs([])
246258
.unwrap();
247259

248-
let (mut hugr, [call_node, h_node, f_node, rx_node]) = {
260+
let (mut hugr, [call_node, h_node, f_node, rx_node, main_node]) = {
249261
let mut builder = mb
250262
.define_function(
251263
"main",
@@ -284,12 +296,18 @@ mod test {
284296
.unwrap()
285297
.outputs_arr();
286298

287-
let _main_n = builder
299+
let main_n = builder
288300
.finish_with_outputs([measure_result, measure_result])
289-
.unwrap();
301+
.unwrap()
302+
.node();
290303
let hugr = mb.finish_hugr().unwrap();
291-
(hugr, [call_node, h_node, f_node, rx_node])
304+
(hugr, [call_node, h_node, f_node, rx_node, main_n])
292305
};
306+
if set_entrypoint {
307+
// set the entrypoint to the main function
308+
// if this is not done the "backwards compatibility" code is triggered
309+
hugr.set_entrypoint(main_node);
310+
}
293311
QSystemPass::default().run(&mut hugr).unwrap();
294312

295313
let topo_sorted = Topo::new(&hugr.as_petgraph())

0 commit comments

Comments
 (0)