Skip to content

Commit f8f1460

Browse files
fix value increment
1 parent f7b1aeb commit f8f1460

File tree

6 files changed

+86
-124
lines changed

6 files changed

+86
-124
lines changed

src/bin/cairo-native-run.rs

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,41 @@ fn main() -> anyhow::Result<()> {
135135
unsafe { *counter_id = 0 };
136136
}
137137
}
138+
139+
#[cfg(feature = "with-libfunc-counter")]
140+
let libfuncs_amount = sierra_program.libfunc_declarations.len();
138141

139142
Box::new(move |function_id, args, gas, syscall_handler| {
140-
executor.invoke_dynamic_with_syscall_handler(
143+
let result = executor.invoke_dynamic_with_syscall_handler(
141144
function_id,
142145
args,
143146
gas,
144147
syscall_handler,
145-
)
148+
);
149+
150+
#[cfg(feature = "with-libfunc-counter")]
151+
unsafe {
152+
use cairo_native::metadata::libfunc_counter::{
153+
libfunc_counter_runtime, LibfuncCounterBinding,
154+
};
155+
156+
let counter_id_ptr = executor
157+
.find_symbol_ptr(LibfuncCounterBinding::CounterId.symbol())
158+
.unwrap()
159+
.cast::<u64>();
160+
let counters_array_ptr_ptr = executor
161+
.find_symbol_ptr(LibfuncCounterBinding::CounterArray.symbol())
162+
.unwrap()
163+
.cast::<*mut u32>();
164+
165+
libfunc_counter_runtime::store_counters_array(
166+
counter_id_ptr,
167+
counters_array_ptr_ptr,
168+
libfuncs_amount,
169+
);
170+
}
171+
172+
result
146173
})
147174
}
148175
RunMode::Jit => {
@@ -152,6 +179,7 @@ fn main() -> anyhow::Result<()> {
152179
#[cfg(feature = "with-trace-dump")]
153180
{
154181
use cairo_native::metadata::trace_dump::TraceBinding;
182+
155183
if let Some(trace_id) = executor.find_symbol_ptr(TraceBinding::TraceId.symbol()) {
156184
let trace_id = trace_id.cast::<u64>();
157185
unsafe { *trace_id = 0 };
@@ -173,6 +201,7 @@ fn main() -> anyhow::Result<()> {
173201
#[cfg(feature = "with-libfunc-counter")]
174202
{
175203
use cairo_native::metadata::libfunc_counter::LibfuncCounterBinding;
204+
176205
if let Some(counter_id) =
177206
executor.find_symbol_ptr(LibfuncCounterBinding::CounterId.symbol())
178207
{
@@ -181,13 +210,40 @@ fn main() -> anyhow::Result<()> {
181210
}
182211
}
183212

213+
#[cfg(feature = "with-libfunc-counter")]
214+
let libfuncs_amount = sierra_program.libfunc_declarations.len();
215+
184216
Box::new(move |function_id, args, gas, syscall_handler| {
185-
executor.invoke_dynamic_with_syscall_handler(
217+
let result = executor.invoke_dynamic_with_syscall_handler(
186218
function_id,
187219
args,
188220
gas,
189221
syscall_handler,
190-
)
222+
);
223+
224+
#[cfg(feature = "with-libfunc-counter")]
225+
unsafe {
226+
use cairo_native::metadata::libfunc_counter::{
227+
libfunc_counter_runtime, LibfuncCounterBinding,
228+
};
229+
230+
let counter_id_ptr = executor
231+
.find_symbol_ptr(LibfuncCounterBinding::CounterId.symbol())
232+
.unwrap()
233+
.cast::<u64>();
234+
let counters_array_ptr_ptr = executor
235+
.find_symbol_ptr(LibfuncCounterBinding::CounterArray.symbol())
236+
.unwrap()
237+
.cast::<*mut u32>();
238+
239+
libfunc_counter_runtime::store_counters_array(
240+
counter_id_ptr,
241+
counters_array_ptr_ptr,
242+
libfuncs_amount,
243+
);
244+
}
245+
246+
result
191247
})
192248
}
193249
};

src/compiler.rs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,10 +1064,6 @@ fn compile_func(
10641064
sierra_stmt_start_offset + function.entry_point.0,
10651065
0,
10661066
),
1067-
#[cfg(feature = "with-libfunc-counter")]
1068-
libfunc_indexes,
1069-
#[cfg(feature = "with-libfunc-counter")]
1070-
metadata,
10711067
)?;
10721068

10731069
tracing::debug!("Done generating function {}.", function.id);
@@ -1449,8 +1445,6 @@ fn generate_entry_point_wrapper<'c>(
14491445
arg_types: &[(Type<'c>, Location<'c>)],
14501446
ret_types: &[Type<'c>],
14511447
location: Location<'c>,
1452-
#[cfg(feature = "with-libfunc-counter")] libfunc_indexes: &HashMap<ConcreteLibfuncId, usize>,
1453-
#[cfg(feature = "with-libfunc-counter")] metadata: &mut MetadataStorage,
14541448
) -> Result<(), Error> {
14551449
let region = Region::new();
14561450
let block = region.append_block(Block::new(arg_types));
@@ -1483,20 +1477,6 @@ fn generate_entry_point_wrapper<'c>(
14831477
returns.push(block.extract_value(context, location, result, *ty, i)?);
14841478
}
14851479

1486-
#[cfg(feature = "with-libfunc-counter")]
1487-
{
1488-
use crate::metadata::libfunc_counter::LibfuncCounterMeta;
1489-
1490-
let libfunc_counter = metadata.get_mut::<LibfuncCounterMeta>().unwrap();
1491-
libfunc_counter.store_array_counter(
1492-
context,
1493-
module,
1494-
&block,
1495-
location,
1496-
libfunc_indexes.len() as u32,
1497-
)?;
1498-
}
1499-
15001480
block.append_operation(func::r#return(&returns, location));
15011481

15021482
module.body().append_operation(func::func(

src/executor/aot.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,6 @@ impl AotNativeExecutor {
6363
#[cfg(feature = "with-libfunc-profiling")]
6464
crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name));
6565

66-
#[cfg(feature = "with-libfunc-counter")]
67-
crate::metadata::libfunc_counter::setup_runtime(|name| executor.find_symbol_ptr(name));
68-
6966
executor
7067
}
7168

src/executor/contract.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,6 @@ impl AotContractExecutor {
333333
#[cfg(feature = "with-libfunc-profiling")]
334334
crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name));
335335

336-
#[cfg(feature = "with-libfunc-counter")]
337-
crate::metadata::libfunc_counter::setup_runtime(|name| executor.find_symbol_ptr(name));
338-
339336
Ok(Some(executor))
340337
}
341338

src/executor/jit.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,6 @@ impl<'m> JitNativeExecutor<'m> {
7474
#[cfg(feature = "with-libfunc-profiling")]
7575
crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name));
7676

77-
#[cfg(feature = "with-libfunc-counter")]
78-
crate::metadata::libfunc_counter::setup_runtime(|name| executor.find_symbol_ptr(name));
79-
8077
Ok(executor)
8178
}
8279

src/metadata/libfunc_counter.rs

Lines changed: 26 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020
//! contract. Since a contract can call other contracts, we need a way of restoring the counter after every execution.
2121
//!
2222
//! See `cairo-native-run` for an example on how to do it.
23-
use std::{collections::HashSet, os::raw::c_void, ptr};
23+
use std::collections::HashSet;
2424

2525
use melior::{
2626
dialect::{llvm, memref, ods},
2727
ir::{
2828
attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute},
29-
operation::OperationBuilder,
3029
r#type::{IntegerType, MemRefType},
3130
Attribute, Block, BlockLike, Location, Module, Region, Value,
3231
},
@@ -41,27 +40,15 @@ use crate::{
4140

4241
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
4342
pub enum LibfuncCounterBinding {
44-
StoreArrayCounter,
4543
CounterId,
46-
ArrayCounter,
44+
CounterArray,
4745
}
4846

4947
impl LibfuncCounterBinding {
5048
pub const fn symbol(self) -> &'static str {
5149
match self {
52-
LibfuncCounterBinding::StoreArrayCounter => "cairo_native__store_array_counter",
5350
LibfuncCounterBinding::CounterId => "cairo_native__counter_id",
54-
LibfuncCounterBinding::ArrayCounter => "cairo_native__array_counter",
55-
}
56-
}
57-
58-
const fn function_ptr(self) -> *const () {
59-
match self {
60-
LibfuncCounterBinding::StoreArrayCounter => {
61-
libfunc_counter_runtime::store_array_counter as *const ()
62-
}
63-
LibfuncCounterBinding::CounterId => ptr::null(),
64-
LibfuncCounterBinding::ArrayCounter => ptr::null(),
51+
LibfuncCounterBinding::CounterArray => "cairo_native__counter_array",
6552
}
6653
}
6754
}
@@ -155,68 +142,24 @@ impl LibfuncCounterMeta {
155142
block.append_op_result(memref::load(libfunc_counter_id_ptr, &[], location))
156143
}
157144

158-
/// Indexes the array of counters and increments the counter relative
159-
/// to the given libfunc index
160-
pub fn store_array_counter(
161-
&mut self,
162-
context: &Context,
163-
module: &Module,
164-
block: &Block<'_>,
165-
location: Location,
166-
libfunc_amount: u32,
167-
) -> Result<()> {
168-
let counter_id = self.build_counter_id(context, module, block, location)?;
169-
let function_ptr = self.build_function(
170-
context,
171-
module,
172-
block,
173-
location,
174-
LibfuncCounterBinding::StoreArrayCounter,
175-
)?;
176-
let lifuncs_amount = block.const_int(context, location, libfunc_amount, 32)?;
177-
// by this time, the array counter should be initialized
178-
let array_counter_ptr_ptr = block.append_op_result(
179-
ods::llvm::mlir_addressof(
180-
context,
181-
llvm::r#type::pointer(context, 0),
182-
FlatSymbolRefAttribute::new(context, LibfuncCounterBinding::ArrayCounter.symbol()),
183-
location,
184-
)
185-
.into(),
186-
)?;
187-
let array_counter_ptr = block.load(
188-
context,
189-
location,
190-
array_counter_ptr_ptr,
191-
llvm::r#type::pointer(context, 0),
192-
)?;
193-
194-
block.append_operation(
195-
OperationBuilder::new("llvm.call", location)
196-
.add_operands(&[function_ptr])
197-
.add_operands(&[counter_id, array_counter_ptr, lifuncs_amount])
198-
.build()?,
199-
);
200-
201-
Ok(())
202-
}
203-
204145
/// Build the array of counters
205-
fn get_array_counter<'c, 'a>(
146+
fn build_array_counter<'c, 'a>(
206147
&mut self,
207148
context: &'c Context,
208149
module: &Module,
209150
block: &'a Block<'c>,
210151
location: Location<'c>,
211152
libfunc_amount: u32,
212153
) -> Result<Value<'c, 'a>> {
213-
if self.active_map.insert(LibfuncCounterBinding::ArrayCounter) {
154+
if self.active_map.insert(LibfuncCounterBinding::CounterArray) {
155+
self.build_counter_id(context, module, block, location)?;
156+
214157
module.body().append_operation(
215158
ods::llvm::mlir_global(
216159
context,
217160
Region::new(),
218161
TypeAttribute::new(llvm::r#type::pointer(context, 0)),
219-
StringAttribute::new(context, LibfuncCounterBinding::ArrayCounter.symbol()),
162+
StringAttribute::new(context, LibfuncCounterBinding::CounterArray.symbol()),
220163
Attribute::parse(context, "#llvm.linkage<weak>")
221164
.ok_or(Error::ParseAttributeError)?,
222165
location,
@@ -240,7 +183,7 @@ impl LibfuncCounterMeta {
240183
llvm::r#type::pointer(context, 0),
241184
FlatSymbolRefAttribute::new(
242185
context,
243-
LibfuncCounterBinding::ArrayCounter.symbol(),
186+
LibfuncCounterBinding::CounterArray.symbol(),
244187
),
245188
location,
246189
)
@@ -264,13 +207,13 @@ impl LibfuncCounterMeta {
264207
ods::llvm::mlir_addressof(
265208
context,
266209
llvm::r#type::pointer(context, 0),
267-
FlatSymbolRefAttribute::new(context, LibfuncCounterBinding::ArrayCounter.symbol()),
210+
FlatSymbolRefAttribute::new(context, LibfuncCounterBinding::CounterArray.symbol()),
268211
location,
269212
)
270213
.into(),
271214
)?;
272215

273-
// // return the pointer to array counter
216+
// return the pointer to array counter
274217
block.load(
275218
context,
276219
location,
@@ -289,10 +232,11 @@ impl LibfuncCounterMeta {
289232
libfuncs_amount: u32,
290233
) -> Result<()> {
291234
let u32_ty = IntegerType::new(context, 32).into();
292-
let k1 = block.const_int(context, location, 0, 32)?;
235+
let k1 = block.const_int(context, location, 1, 32)?;
293236

294237
let array_counter_ptr =
295-
self.get_array_counter(context, module, block, location, libfuncs_amount)?;
238+
self.build_array_counter(context, module, block, location, libfuncs_amount)?;
239+
296240
let value_counter_ptr = block.gep(
297241
context,
298242
location,
@@ -310,24 +254,13 @@ impl LibfuncCounterMeta {
310254
}
311255
}
312256

313-
pub fn setup_runtime(find_symbol_ptr: impl Fn(&str) -> Option<*mut c_void>) {
314-
let bindings = &[LibfuncCounterBinding::StoreArrayCounter];
315-
316-
for binding in bindings {
317-
if let Some(global) = find_symbol_ptr(binding.symbol()) {
318-
let global = global.cast::<*const ()>();
319-
unsafe { *global = binding.function_ptr() };
320-
}
321-
}
322-
}
323-
324257
pub mod libfunc_counter_runtime {
258+
use core::slice;
325259
use std::{
326260
collections::HashMap,
327261
sync::{LazyLock, Mutex},
328262
};
329263

330-
use itertools::Itertools;
331264
use melior::{
332265
ir::{Block, Location, Module},
333266
Context,
@@ -364,16 +297,18 @@ pub mod libfunc_counter_runtime {
364297
)
365298
}
366299

367-
pub unsafe extern "C" fn store_array_counter(
368-
counter_id: u64,
369-
array_counter: *const u32,
370-
libfuncs_amount: u32,
300+
pub unsafe fn store_counters_array(
301+
counter_id_ptr: *mut u64,
302+
array_ptr_ptr: *mut *mut u32,
303+
libfuncs_amount: usize,
371304
) {
372-
let mut libfunc_counter = LIBFUNC_COUNTER.lock().unwrap();
373-
let vec = (0..libfuncs_amount)
374-
.map(|i| *array_counter.add(i as usize))
375-
.collect_vec();
305+
let counters_vec = slice::from_raw_parts(*array_ptr_ptr, libfuncs_amount).to_vec();
306+
307+
LIBFUNC_COUNTER
308+
.lock()
309+
.unwrap()
310+
.insert(*counter_id_ptr, counters_vec);
376311

377-
libfunc_counter.insert(counter_id, vec);
312+
libc::free(*array_ptr_ptr.cast::<*mut libc::c_void>());
378313
}
379314
}

0 commit comments

Comments
 (0)