20
20
//! contract. Since a contract can call other contracts, we need a way of restoring the counter after every execution.
21
21
//!
22
22
//! See `cairo-native-run` for an example on how to do it.
23
- use std:: { collections:: HashSet , os :: raw :: c_void , ptr } ;
23
+ use std:: collections:: HashSet ;
24
24
25
25
use melior:: {
26
26
dialect:: { llvm, memref, ods} ,
27
27
ir:: {
28
28
attribute:: { FlatSymbolRefAttribute , StringAttribute , TypeAttribute } ,
29
- operation:: OperationBuilder ,
30
29
r#type:: { IntegerType , MemRefType } ,
31
30
Attribute , Block , BlockLike , Location , Module , Region , Value ,
32
31
} ,
@@ -41,27 +40,15 @@ use crate::{
41
40
42
41
#[ derive( Clone , Copy , Debug , Hash , PartialEq , Eq ) ]
43
42
pub enum LibfuncCounterBinding {
44
- StoreArrayCounter ,
45
43
CounterId ,
46
- ArrayCounter ,
44
+ CounterArray ,
47
45
}
48
46
49
47
impl LibfuncCounterBinding {
50
48
pub const fn symbol ( self ) -> & ' static str {
51
49
match self {
52
- LibfuncCounterBinding :: StoreArrayCounter => "cairo_native__store_array_counter" ,
53
50
LibfuncCounterBinding :: CounterId => "cairo_native__counter_id" ,
54
- LibfuncCounterBinding :: ArrayCounter => "cairo_native__array_counter" ,
55
- }
56
- }
57
-
58
- const fn function_ptr ( self ) -> * const ( ) {
59
- match self {
60
- LibfuncCounterBinding :: StoreArrayCounter => {
61
- libfunc_counter_runtime:: store_array_counter as * const ( )
62
- }
63
- LibfuncCounterBinding :: CounterId => ptr:: null ( ) ,
64
- LibfuncCounterBinding :: ArrayCounter => ptr:: null ( ) ,
51
+ LibfuncCounterBinding :: CounterArray => "cairo_native__counter_array" ,
65
52
}
66
53
}
67
54
}
@@ -155,68 +142,24 @@ impl LibfuncCounterMeta {
155
142
block. append_op_result ( memref:: load ( libfunc_counter_id_ptr, & [ ] , location) )
156
143
}
157
144
158
- /// Indexes the array of counters and increments the counter relative
159
- /// to the given libfunc index
160
- pub fn store_array_counter (
161
- & mut self ,
162
- context : & Context ,
163
- module : & Module ,
164
- block : & Block < ' _ > ,
165
- location : Location ,
166
- libfunc_amount : u32 ,
167
- ) -> Result < ( ) > {
168
- let counter_id = self . build_counter_id ( context, module, block, location) ?;
169
- let function_ptr = self . build_function (
170
- context,
171
- module,
172
- block,
173
- location,
174
- LibfuncCounterBinding :: StoreArrayCounter ,
175
- ) ?;
176
- let lifuncs_amount = block. const_int ( context, location, libfunc_amount, 32 ) ?;
177
- // by this time, the array counter should be initialized
178
- let array_counter_ptr_ptr = block. append_op_result (
179
- ods:: llvm:: mlir_addressof (
180
- context,
181
- llvm:: r#type:: pointer ( context, 0 ) ,
182
- FlatSymbolRefAttribute :: new ( context, LibfuncCounterBinding :: ArrayCounter . symbol ( ) ) ,
183
- location,
184
- )
185
- . into ( ) ,
186
- ) ?;
187
- let array_counter_ptr = block. load (
188
- context,
189
- location,
190
- array_counter_ptr_ptr,
191
- llvm:: r#type:: pointer ( context, 0 ) ,
192
- ) ?;
193
-
194
- block. append_operation (
195
- OperationBuilder :: new ( "llvm.call" , location)
196
- . add_operands ( & [ function_ptr] )
197
- . add_operands ( & [ counter_id, array_counter_ptr, lifuncs_amount] )
198
- . build ( ) ?,
199
- ) ;
200
-
201
- Ok ( ( ) )
202
- }
203
-
204
145
/// Build the array of counters
205
- fn get_array_counter < ' c , ' a > (
146
+ fn build_array_counter < ' c , ' a > (
206
147
& mut self ,
207
148
context : & ' c Context ,
208
149
module : & Module ,
209
150
block : & ' a Block < ' c > ,
210
151
location : Location < ' c > ,
211
152
libfunc_amount : u32 ,
212
153
) -> Result < Value < ' c , ' a > > {
213
- if self . active_map . insert ( LibfuncCounterBinding :: ArrayCounter ) {
154
+ if self . active_map . insert ( LibfuncCounterBinding :: CounterArray ) {
155
+ self . build_counter_id ( context, module, block, location) ?;
156
+
214
157
module. body ( ) . append_operation (
215
158
ods:: llvm:: mlir_global (
216
159
context,
217
160
Region :: new ( ) ,
218
161
TypeAttribute :: new ( llvm:: r#type:: pointer ( context, 0 ) ) ,
219
- StringAttribute :: new ( context, LibfuncCounterBinding :: ArrayCounter . symbol ( ) ) ,
162
+ StringAttribute :: new ( context, LibfuncCounterBinding :: CounterArray . symbol ( ) ) ,
220
163
Attribute :: parse ( context, "#llvm.linkage<weak>" )
221
164
. ok_or ( Error :: ParseAttributeError ) ?,
222
165
location,
@@ -240,7 +183,7 @@ impl LibfuncCounterMeta {
240
183
llvm:: r#type:: pointer ( context, 0 ) ,
241
184
FlatSymbolRefAttribute :: new (
242
185
context,
243
- LibfuncCounterBinding :: ArrayCounter . symbol ( ) ,
186
+ LibfuncCounterBinding :: CounterArray . symbol ( ) ,
244
187
) ,
245
188
location,
246
189
)
@@ -264,13 +207,13 @@ impl LibfuncCounterMeta {
264
207
ods:: llvm:: mlir_addressof (
265
208
context,
266
209
llvm:: r#type:: pointer ( context, 0 ) ,
267
- FlatSymbolRefAttribute :: new ( context, LibfuncCounterBinding :: ArrayCounter . symbol ( ) ) ,
210
+ FlatSymbolRefAttribute :: new ( context, LibfuncCounterBinding :: CounterArray . symbol ( ) ) ,
268
211
location,
269
212
)
270
213
. into ( ) ,
271
214
) ?;
272
215
273
- // // return the pointer to array counter
216
+ // return the pointer to array counter
274
217
block. load (
275
218
context,
276
219
location,
@@ -289,10 +232,11 @@ impl LibfuncCounterMeta {
289
232
libfuncs_amount : u32 ,
290
233
) -> Result < ( ) > {
291
234
let u32_ty = IntegerType :: new ( context, 32 ) . into ( ) ;
292
- let k1 = block. const_int ( context, location, 0 , 32 ) ?;
235
+ let k1 = block. const_int ( context, location, 1 , 32 ) ?;
293
236
294
237
let array_counter_ptr =
295
- self . get_array_counter ( context, module, block, location, libfuncs_amount) ?;
238
+ self . build_array_counter ( context, module, block, location, libfuncs_amount) ?;
239
+
296
240
let value_counter_ptr = block. gep (
297
241
context,
298
242
location,
@@ -310,24 +254,13 @@ impl LibfuncCounterMeta {
310
254
}
311
255
}
312
256
313
- pub fn setup_runtime ( find_symbol_ptr : impl Fn ( & str ) -> Option < * mut c_void > ) {
314
- let bindings = & [ LibfuncCounterBinding :: StoreArrayCounter ] ;
315
-
316
- for binding in bindings {
317
- if let Some ( global) = find_symbol_ptr ( binding. symbol ( ) ) {
318
- let global = global. cast :: < * const ( ) > ( ) ;
319
- unsafe { * global = binding. function_ptr ( ) } ;
320
- }
321
- }
322
- }
323
-
324
257
pub mod libfunc_counter_runtime {
258
+ use core:: slice;
325
259
use std:: {
326
260
collections:: HashMap ,
327
261
sync:: { LazyLock , Mutex } ,
328
262
} ;
329
263
330
- use itertools:: Itertools ;
331
264
use melior:: {
332
265
ir:: { Block , Location , Module } ,
333
266
Context ,
@@ -364,16 +297,18 @@ pub mod libfunc_counter_runtime {
364
297
)
365
298
}
366
299
367
- pub unsafe extern "C" fn store_array_counter (
368
- counter_id : u64 ,
369
- array_counter : * const u32 ,
370
- libfuncs_amount : u32 ,
300
+ pub unsafe fn store_counters_array (
301
+ counter_id_ptr : * mut u64 ,
302
+ array_ptr_ptr : * mut * mut u32 ,
303
+ libfuncs_amount : usize ,
371
304
) {
372
- let mut libfunc_counter = LIBFUNC_COUNTER . lock ( ) . unwrap ( ) ;
373
- let vec = ( 0 ..libfuncs_amount)
374
- . map ( |i| * array_counter. add ( i as usize ) )
375
- . collect_vec ( ) ;
305
+ let counters_vec = slice:: from_raw_parts ( * array_ptr_ptr, libfuncs_amount) . to_vec ( ) ;
306
+
307
+ LIBFUNC_COUNTER
308
+ . lock ( )
309
+ . unwrap ( )
310
+ . insert ( * counter_id_ptr, counters_vec) ;
376
311
377
- libfunc_counter . insert ( counter_id , vec ) ;
312
+ libc :: free ( * array_ptr_ptr . cast :: < * mut libc :: c_void > ( ) ) ;
378
313
}
379
314
}
0 commit comments