Skip to content

MLIR implementation of euclidean algorithm #1357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 129 additions & 1 deletion src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::Itertools;
use melior::{
dialect::{
arith::CmpiPredicate,
arith::{self, CmpiPredicate},
cf, func, index,
llvm::{self, LoadStoreOptions},
memref,
Expand Down Expand Up @@ -141,6 +141,26 @@ pub fn compile(
}
}

let location = Location::unknown(context); // TODO: WHICH LOCATION SHOULD I USE?
let integer_type: Type = IntegerType::new(context, 384 * 2).into();
let region = declare_euclidean_func(context, location, integer_type);
let func_name = StringAttribute::new(context, "cairo_native__euclidean_algorithm");
module.body().append_operation(llvm::func(
context,
func_name,
TypeAttribute::new(llvm::r#type::function(
llvm::r#type::r#struct(context, &[integer_type, integer_type], false),
&[integer_type, integer_type],
false,
)),
region,
&[(
Identifier::new(context, "no_inline"),
Attribute::unit(context),
)],
location,
));

// Sierra programs have the following structure:
// 1. Type declarations, one per line.
// 2. Libfunc declarations, one per line.
Expand Down Expand Up @@ -171,6 +191,114 @@ pub fn compile(
Ok(())
}

fn declare_euclidean_func<'ctx>(
context: &'ctx Context,
location: Location<'ctx>,
integer_type: Type<'_>,
) -> Region<'ctx> {
let region = Region::new();

let entry_block = region.append_block(Block::new(&[
(integer_type, location),
(integer_type, location),
]));

// The algorithm egcd works by calculating a series of remainders, each the remainder of dividing the previous two
// For the initial setup, r0 = b, r1 = a
// This order is chosen because if we reverse them, then the first iteration will just swap them
let remainder = entry_block.arg(0).unwrap();
let prev_remainder = entry_block.arg(1).unwrap();

// Similarly we'll calculate another series which starts 0,1,... and from which we will retrieve the modular inverse of a
let prev_inverse = entry_block
.const_int_from_type(context, location, 0, integer_type)
.unwrap();
let inverse = entry_block
.const_int_from_type(context, location, 1, integer_type)
.unwrap();

let loop_block = region.append_block(Block::new(&[
(integer_type, location),
(integer_type, location),
(integer_type, location),
(integer_type, location),
]));
let end_block = region.append_block(Block::new(&[
(integer_type, location),
(integer_type, location),
]));

entry_block.append_operation(cf::br(
&loop_block,
&[prev_remainder, remainder, prev_inverse, inverse],
location,
));

// -- Loop body --
// Arguments are rem_(i-1), rem, inv_(i-1), inv
let prev_remainder = loop_block.arg(0).unwrap();
let remainder = loop_block.arg(1).unwrap();
let prev_inverse = loop_block.arg(2).unwrap();
let inverse = loop_block.arg(3).unwrap();

// First calculate q = rem_(i-1)/rem_i, rounded down
let quotient = loop_block
.append_op_result(arith::divui(prev_remainder, remainder, location))
.unwrap();

// Then r_(i+1) = r_(i-1) - q * r_i, and inv_(i+1) = inv_(i-1) - q * inv_i
let rem_times_quo = loop_block.muli(remainder, quotient, location).unwrap();
let inv_times_quo = loop_block.muli(inverse, quotient, location).unwrap();
let next_remainder = loop_block
.append_op_result(arith::subi(prev_remainder, rem_times_quo, location))
.unwrap();
let next_inverse = loop_block
.append_op_result(arith::subi(prev_inverse, inv_times_quo, location))
.unwrap();

// Check if r_(i+1) is 0
// If true, then:
// - r_i is the gcd of a and b
// - inv_i is the bezout coefficient x

let zero = loop_block
.const_int_from_type(context, location, 0, integer_type)
.unwrap();
let next_remainder_eq_zero = loop_block
.cmpi(context, CmpiPredicate::Eq, next_remainder, zero, location)
.unwrap();
loop_block.append_operation(cf::cond_br(
context,
next_remainder_eq_zero,
&end_block,
&loop_block,
&[remainder, inverse],
&[remainder, next_remainder, inverse, next_inverse],
location,
));
// loop_block.append_operation(cf::br(&end_block, &[remainder, inverse], location));

//////// SAME AS libsfuncs/array.rs line 213 ////////
let results = end_block
.append_op_result(llvm::undef(
llvm::r#type::r#struct(context, &[integer_type, integer_type], false),
location,
))
.unwrap();
let results = end_block
.insert_values(
context,
location,
results,
&[end_block.arg(0).unwrap(), end_block.arg(1).unwrap()],
)
.unwrap();
////////////////////////////////////////////////
end_block.append_operation(llvm::r#return(Some(results), location));

region
}

/// Compile a single Sierra function.
///
/// The function accepts a `Function` argument, which provides the function's entry point, signature
Expand Down
151 changes: 61 additions & 90 deletions src/libfuncs/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ use melior::{
cf, llvm,
},
helpers::{ArithBlockExt, BuiltinBlockExt, GepIndex, LlvmBlockExt},
ir::{r#type::IntegerType, Block, BlockLike, Location, Type, Value, ValueLike},
ir::{
attribute::FlatSymbolRefAttribute, operation::OperationBuilder, r#type::IntegerType,
Attribute, Block, BlockLike, Identifier, Location, Type, Value, ValueLike,
},
Context,
};
use num_traits::Signed;
Expand Down Expand Up @@ -638,17 +641,33 @@ fn build_gate_evaluation<'ctx, 'this>(
let integer_type = rhs_value.r#type();

// Apply egcd to find gcd and inverse
let egcd_result_block = build_euclidean_algorithm(
let euclidean_result =
call_euclidean_func(context, block, location, rhs_value, circuit_modulus);
let gcd = block.extract_value(
context,
block,
location,
helper,
rhs_value,
circuit_modulus,
euclidean_result,
integer_type,
0,
)?;
let gcd = egcd_result_block.arg(0)?;
let inverse = egcd_result_block.arg(1)?;
block = egcd_result_block;
let inverse = block.extract_value(
context,
location,
euclidean_result,
integer_type,
1,
)?;
// let egcd_result_block = build_euclidean_algorithm(
// context,
// block,
// location,
// helper,
// rhs_value,
// circuit_modulus,
// )?;
// let gcd = egcd_result_block.arg(0)?;
// let inverse = egcd_result_block.arg(1)?;
// block = egcd_result_block;

// if the gcd is not 1, then fail (a and b are not coprimes)
let one = block.const_int_from_type(context, location, 1, integer_type)?;
Expand Down Expand Up @@ -719,6 +738,39 @@ fn build_gate_evaluation<'ctx, 'this>(
Ok(([ok_block, err_block], evaluated_gates))
}

fn call_euclidean_func<'ctx>(
context: &'ctx Context,
block: &'ctx Block<'ctx>,
location: Location<'ctx>,
a: Value<'ctx, '_>,
b: Value<'ctx, '_>,
) -> Value<'ctx, 'ctx> {
let integer_type: Type = IntegerType::new(context, 384 * 2).into();
let return_type = llvm::r#type::r#struct(context, &[integer_type, integer_type], false);
block
.append_operation(
OperationBuilder::new("llvm.call", location)
.add_attributes(&[
(
Identifier::new(context, "callee"),
FlatSymbolRefAttribute::new(context, "cairo_native__euclidean_algorithm")
.into(),
),
(
Identifier::new(context, "no_inline"),
Attribute::unit(context),
),
])
.add_operands(&[a, b])
.add_results(&[return_type])
.build()
.unwrap(),
)
.result(0)
.unwrap()
.into()
}

/// Generate MLIR operations for the `circuit_failure_guarantee_verify` libfunc.
/// NOOP
#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -1037,87 +1089,6 @@ fn u384_integer_to_struct<'a>(
)?)
}

/// The extended euclidean algorithm calculates the greatest common divisor (gcd) of two integers a and b,
/// as well as the bezout coefficients x and y such that ax+by=gcd(a,b)
/// if gcd(a,b) = 1, then x is the modular multiplicative inverse of a modulo b.
/// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
///
/// Given two numbers a, b. It returns a block with gcd(a, b) and the bezout coefficient x.
fn build_euclidean_algorithm<'ctx, 'this>(
context: &'ctx Context,
block: &'ctx Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
a: Value<'ctx, 'ctx>,
b: Value<'ctx, 'ctx>,
) -> Result<&'this Block<'ctx>> {
let integer_type = a.r#type();

let loop_block = helper.append_block(Block::new(&[
(integer_type, location),
(integer_type, location),
(integer_type, location),
(integer_type, location),
]));
let end_block = helper.append_block(Block::new(&[
(integer_type, location),
(integer_type, location),
]));

// The algorithm egcd works by calculating a series of remainders, each the remainder of dividing the previous two
// For the initial setup, r0 = b, r1 = a
// This order is chosen because if we reverse them, then the first iteration will just swap them
let prev_remainder = b;
let remainder = a;
// Similarly we'll calculate another series which starts 0,1,... and from which we will retrieve the modular inverse of a
let prev_inverse = block.const_int_from_type(context, location, 0, integer_type)?;
let inverse = block.const_int_from_type(context, location, 1, integer_type)?;
block.append_operation(cf::br(
loop_block,
&[prev_remainder, remainder, prev_inverse, inverse],
location,
));

// -- Loop body --
// Arguments are rem_(i-1), rem, inv_(i-1), inv
let prev_remainder = loop_block.arg(0)?;
let remainder = loop_block.arg(1)?;
let prev_inverse = loop_block.arg(2)?;
let inverse = loop_block.arg(3)?;

// First calculate q = rem_(i-1)/rem_i, rounded down
let quotient =
loop_block.append_op_result(arith::divui(prev_remainder, remainder, location))?;

// Then r_(i+1) = r_(i-1) - q * r_i, and inv_(i+1) = inv_(i-1) - q * inv_i
let rem_times_quo = loop_block.muli(remainder, quotient, location)?;
let inv_times_quo = loop_block.muli(inverse, quotient, location)?;
let next_remainder =
loop_block.append_op_result(arith::subi(prev_remainder, rem_times_quo, location))?;
let next_inverse =
loop_block.append_op_result(arith::subi(prev_inverse, inv_times_quo, location))?;

// Check if r_(i+1) is 0
// If true, then:
// - r_i is the gcd of a and b
// - inv_i is the bezout coefficient x

let zero = loop_block.const_int_from_type(context, location, 0, integer_type)?;
let next_remainder_eq_zero =
loop_block.cmpi(context, CmpiPredicate::Eq, next_remainder, zero, location)?;
loop_block.append_operation(cf::cond_br(
context,
next_remainder_eq_zero,
end_block,
loop_block,
&[remainder, inverse],
&[remainder, next_remainder, inverse, next_inverse],
location,
));

Ok(end_block)
}

/// Extracts values from indexes `from` - `to` (exclusive) and builds a new value of type `result_type`
///
/// Can be used with arrays, or structs with multiple elements of a single type.
Expand Down
Loading