@@ -10,7 +10,7 @@ use hugr::{
10
10
ComposablePass as _, LinearizeArrayPass , MonomorphizePass , RemoveDeadFuncsError ,
11
11
RemoveDeadFuncsPass ,
12
12
} ,
13
- hugr:: HugrError ,
13
+ hugr:: { hugrmut :: HugrMut , HugrError } ,
14
14
Hugr , HugrView , Node ,
15
15
} ;
16
16
use replace_bools:: { ReplaceBoolPass , ReplaceBoolPassError } ;
@@ -82,22 +82,28 @@ pub enum QSystemPassError<N = Node> {
82
82
impl QSystemPass {
83
83
/// Run `QSystemPass` on the given [Hugr]. `registry` is used for
84
84
/// validation, if enabled.
85
+ /// Expects the HUGR to have a function entrypoint.
85
86
pub fn run ( & self , hugr : & mut Hugr ) -> Result < ( ) , QSystemPassError > {
87
+ let entrypoint = if hugr. entrypoint_optype ( ) . is_module ( ) {
88
+ // backwards compatibility: if the entrypoint is a module, we look for
89
+ // a function named "main" in the module and use that as the entrypoint.
90
+ hugr. children ( hugr. entrypoint ( ) )
91
+ . find ( |& n| {
92
+ hugr. get_optype ( n)
93
+ . as_func_defn ( )
94
+ . is_some_and ( |fd| fd. func_name ( ) == "main" )
95
+ } )
96
+ . ok_or ( QSystemPassError :: NoMain ) ?
97
+ } else {
98
+ hugr. entrypoint ( )
99
+ } ;
100
+
101
+ // passes that run on whole module
102
+ hugr. set_entrypoint ( hugr. module_root ( ) ) ;
86
103
if self . monomorphize {
87
104
self . monomorphization ( ) . run ( hugr) . unwrap ( ) ;
88
105
89
- let mut rdfp = RemoveDeadFuncsPass :: default ( ) ;
90
- if hugr. entrypoint_optype ( ) . is_module ( ) {
91
- let main_node = hugr
92
- . children ( hugr. entrypoint ( ) )
93
- . find ( |& n| {
94
- hugr. get_optype ( n)
95
- . as_func_defn ( )
96
- . is_some_and ( |fd| fd. func_name ( ) == "main" )
97
- } )
98
- . ok_or ( QSystemPassError :: NoMain ) ?;
99
- rdfp = rdfp. with_module_entry_points ( [ main_node] ) ;
100
- }
106
+ let rdfp = RemoveDeadFuncsPass :: default ( ) . with_module_entry_points ( [ entrypoint] ) ;
101
107
rdfp. run ( hugr) ?
102
108
}
103
109
@@ -112,6 +118,8 @@ impl QSystemPass {
112
118
if self . force_order {
113
119
self . force_order ( hugr) ?;
114
120
}
121
+ // restore the entrypoint
122
+ hugr. set_entrypoint ( entrypoint) ;
115
123
Ok ( ( ) )
116
124
}
117
125
@@ -220,6 +228,7 @@ mod test {
220
228
use hugr:: {
221
229
builder:: { Container , Dataflow , DataflowSubContainer , HugrBuilder } ,
222
230
extension:: prelude:: qb_t,
231
+ hugr:: hugrmut:: HugrMut ,
223
232
ops:: handle:: NodeHandle ,
224
233
std_extensions:: arithmetic:: float_types:: ConstF64 ,
225
234
type_row,
@@ -229,23 +238,26 @@ mod test {
229
238
230
239
use itertools:: Itertools as _;
231
240
use petgraph:: visit:: { Topo , Walker as _} ;
241
+ use rstest:: rstest;
232
242
use tket2:: extension:: bool:: bool_type;
233
243
234
244
use crate :: {
235
245
extension:: { futures:: FutureOpDef , qsystem:: QSystemOp } ,
236
246
QSystemPass ,
237
247
} ;
238
248
239
- #[ test]
240
- fn qsystem_pass ( ) {
249
+ #[ rstest]
250
+ #[ case( false ) ]
251
+ #[ case( true ) ]
252
+ fn qsystem_pass ( #[ case] set_entrypoint : bool ) {
241
253
let mut mb = hugr:: builder:: ModuleBuilder :: new ( ) ;
242
254
let func = mb
243
255
. define_function ( "func" , Signature :: new_endo ( type_row ! [ ] ) )
244
256
. unwrap ( )
245
257
. finish_with_outputs ( [ ] )
246
258
. unwrap ( ) ;
247
259
248
- let ( mut hugr, [ call_node, h_node, f_node, rx_node] ) = {
260
+ let ( mut hugr, [ call_node, h_node, f_node, rx_node, main_node ] ) = {
249
261
let mut builder = mb
250
262
. define_function (
251
263
"main" ,
@@ -284,12 +296,18 @@ mod test {
284
296
. unwrap ( )
285
297
. outputs_arr ( ) ;
286
298
287
- let _main_n = builder
299
+ let main_n = builder
288
300
. finish_with_outputs ( [ measure_result, measure_result] )
289
- . unwrap ( ) ;
301
+ . unwrap ( )
302
+ . node ( ) ;
290
303
let hugr = mb. finish_hugr ( ) . unwrap ( ) ;
291
- ( hugr, [ call_node, h_node, f_node, rx_node] )
304
+ ( hugr, [ call_node, h_node, f_node, rx_node, main_n ] )
292
305
} ;
306
+ if set_entrypoint {
307
+ // set the entrypoint to the main function
308
+ // if this is not done the "backwards compatibility" code is triggered
309
+ hugr. set_entrypoint ( main_node) ;
310
+ }
293
311
QSystemPass :: default ( ) . run ( & mut hugr) . unwrap ( ) ;
294
312
295
313
let topo_sorted = Topo :: new ( & hugr. as_petgraph ( ) )
0 commit comments