diff --git a/Cargo.lock b/Cargo.lock index fef30482c..886230509 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5346,6 +5346,7 @@ dependencies = [ "hyle-client-sdk", "hyle-contract-sdk", "hyle-hyllar", + "rand 0.8.5", "risc0-zkvm", "serde", "sha2", diff --git a/crates/contracts/staking/Cargo.toml b/crates/contracts/staking/Cargo.toml index 4cb313801..15fbe9706 100644 --- a/crates/contracts/staking/Cargo.toml +++ b/crates/contracts/staking/Cargo.toml @@ -24,6 +24,7 @@ borsh = { workspace = true, features = ["derive"] } risc0-zkvm = { workspace = true, optional = true, features = ['std'] } client-sdk = { workspace = true, features = ["risc0"], optional = true } +rand.workspace = true [dev-dependencies] risc0-zkvm = { workspace = true, features = ['std', 'prove'] } diff --git a/crates/contracts/staking/src/state.rs b/crates/contracts/staking/src/state.rs index c95bfbe20..9a7c672d7 100644 --- a/crates/contracts/staking/src/state.rs +++ b/crates/contracts/staking/src/state.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; use anyhow::Result; use borsh::{BorshDeserialize, BorshSerialize}; +use rand::seq::SliceRandom; use sdk::{info, BlockHeight, Identity, LaneBytesSize, LaneId, ValidatorPublicKey}; use serde::{Deserialize, Serialize}; @@ -26,6 +27,36 @@ pub struct Staking { /// Minimal stake necessary to be part of consensus pub const MIN_STAKE: u128 = 32; +#[derive(Debug, PartialEq, Eq)] +pub enum CertificateReliability { + None, // _ < f + 1 + Weak, // f + 1 <= _ < 2f + 1 + Reliable, // 2f + 1 <= _ < 3f + 1 + Full, // 3f + 1 <= _ +} + +// Implémentation de Ord et PartialOrd +impl PartialOrd for CertificateReliability { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for CertificateReliability { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + use CertificateReliability::*; + let rank = |r: &CertificateReliability| -> u8 { + match r { + None => 0, + Weak => 1, + Reliable => 2, + Full => 3, + } + }; + rank(self).cmp(&rank(other)) + } +} + impl Staking { pub fn new() -> Self { Staking { @@ -73,6 +104,55 @@ impl Staking { } } + /// Returns a random list of validators to add to present validators, to form a weak quorum + pub fn choose_weak_quorum<'a, R>( + &'a self, + present_pubkeys: Vec<&'a ValidatorPublicKey>, + rng: &mut R, + ) -> Result> + where + R: rand::Rng + ?Sized, + { + let mut validators: Vec<&ValidatorPublicKey> = self + .bonded() + .iter() + .filter(|v| !present_pubkeys.contains(v)) + .collect(); + + let mut res: Vec<&ValidatorPublicKey> = vec![]; + + validators.shuffle(rng); + + let mut power: u128 = present_pubkeys + .iter() + .filter_map(|pp| self.get_stake(pp)) + .sum(); + let f = self.compute_f(); + + while power < f + 1 && !validators.is_empty() { + let random = validators.remove(0); + power += self.get_stake(random).unwrap_or(0); + res.push(random); + } + + Ok(res) + } + + pub fn check_reliability(&self, validators: &[ValidatorPublicKey]) -> CertificateReliability { + let f = self.compute_f(); + let power = self.compute_voting_power(validators); + + if power < f + 1 { + return CertificateReliability::None; + } else if power < 2 * f + 1 { + return CertificateReliability::Weak; + } else if power < 3 * f + 1 { + return CertificateReliability::Reliable; + } else { + return CertificateReliability::Full; + } + } + /// Compute f value pub fn compute_f(&self) -> u128 { self.total_bond().div_euclid(3) diff --git a/src/mempool.rs b/src/mempool.rs index 12ac00b94..7b322e9e1 100644 --- a/src/mempool.rs +++ b/src/mempool.rs @@ -574,6 +574,21 @@ impl Mempool { Ok(()) } + fn broadcast_weak(&mut self, net_message: MempoolNetMessage) -> Result<()> { + let own_key = self.crypto.validator_pubkey(); + let selected: HashSet = self + .staking + .choose_weak_quorum(vec![own_key], &mut rand::thread_rng()) + .context("Choosing validators for a weak certificate")? + .into_iter() + .map(|s| s.clone()) + .collect(); + + _ = self.broadcast_only_for_net_message(selected, net_message)?; + + Ok(()) + } + #[inline(always)] fn broadcast_net_message(&mut self, net_message: MempoolNetMessage) -> Result<()> { let enum_variant_name: &'static str = (&net_message).into(); diff --git a/src/mempool/own_lane.rs b/src/mempool/own_lane.rs index f9f606460..357bb0f48 100644 --- a/src/mempool/own_lane.rs +++ b/src/mempool/own_lane.rs @@ -6,6 +6,7 @@ use crate::{bus::BusClientSender, model::*}; use anyhow::{bail, Context, Result}; use client_sdk::tcp_client::TcpServerMessage; use futures::StreamExt; +use staking::state::CertificateReliability; use std::collections::HashSet; use tracing::{debug, trace}; @@ -159,8 +160,16 @@ impl super::Mempool { return Ok(true); }; - self.rebroadcast_data_proposal(&metadata, &dp_hash) - .context("Rebroadcasting oldest DataProposal") + if self + .staking + .check_reliability(metadata.validators().as_slice()) + < CertificateReliability::Weak + { + self.rebroadcast_data_proposal(&metadata, &dp_hash) + .context("Rebroadcasting oldest DataProposal") + } else { + Ok(false) + } } /// Rebroadcast DataProposal to validators that have not signed it yet. @@ -191,26 +200,26 @@ impl super::Mempool { self.metrics .dp_disseminations .add(self.staking.bonded().len() as u64, &[]); - self.broadcast_net_message(MempoolNetMessage::DataProposal( + self.broadcast_weak(MempoolNetMessage::DataProposal( data_proposal.hashed(), data_proposal.clone(), ))?; } else { // If None, rebroadcast it to every validator that has not yet signed it - let validator_that_has_signed: HashSet<&ValidatorPublicKey> = entry_metadata + let signators: HashSet<&ValidatorPublicKey> = entry_metadata .signatures .iter() .map(|s| &s.signature.validator) .collect(); + let signators: Vec<&ValidatorPublicKey> = signators.into_iter().collect(); // No PoA means we rebroadcast the DataProposal for non present voters - let only_for: HashSet = self - .staking - .bonded() - .iter() - .filter(|pubkey| !validator_that_has_signed.contains(pubkey)) - .cloned() - .collect(); + let only_for: HashSet = HashSet::from_iter( + self.staking + .choose_weak_quorum(signators, &mut rand::thread_rng())? + .into_iter() + .cloned(), + ); if only_for.is_empty() { return Ok(false); @@ -615,7 +624,7 @@ pub mod test { ctx.process_new_data_proposal(dp)?; ctx.timer_tick().await?; - let data_proposal = match ctx.assert_broadcast("DataProposal").await.msg { + let data_proposal = match ctx.assert_broadcast_only_for("DataProposal").await.1.msg { MempoolNetMessage::DataProposal(_, dp) => dp, _ => panic!("Expected DataProposal message"), }; @@ -767,8 +776,12 @@ pub mod test { // Récupère les deux DataProposals broadcastées par ctx1 let mut dps = vec![]; for _ in 0..2 { - match ctx1.assert_broadcast("DataProposal").await.msg { - MempoolNetMessage::DataProposal(hash, dp) => dps.push((hash, dp)), + let (set, msg) = ctx1.assert_broadcast_only_for("DataProposal").await; + match msg.msg { + MempoolNetMessage::DataProposal(hash, dp) => { + assert_eq!(set.len(), 1); + dps.push((hash, dp)); + } _ => panic!("Expected DataProposal message"), } } @@ -786,8 +799,12 @@ pub mod test { // Récupère les deux DataProposals broadcastées par ctx1 let mut dps = vec![]; for _ in 0..1 { - match ctx1.assert_broadcast("DataProposal").await.msg { - MempoolNetMessage::DataProposal(hash, dp) => dps.push((hash, dp)), + let (set, msg) = ctx1.assert_broadcast_only_for("DataProposal").await; + match msg.msg { + MempoolNetMessage::DataProposal(hash, dp) => { + assert_eq!(set.len(), 1); + dps.push((hash, dp)); + } _ => panic!("Expected DataProposal message"), } } @@ -801,8 +818,12 @@ pub mod test { // Récupère les deux DataProposals broadcastées par ctx1 let mut dps = vec![]; for _ in 0..1 { - match ctx1.assert_broadcast("DataProposal").await.msg { - MempoolNetMessage::DataProposal(hash, dp) => dps.push((hash, dp)), + let (set, msg) = ctx1.assert_broadcast_only_for("DataProposal").await; + match msg.msg { + MempoolNetMessage::DataProposal(hash, dp) => { + assert_eq!(set.len(), 1); + dps.push((hash, dp)); + } _ => panic!("Expected DataProposal message"), } } diff --git a/src/mempool/storage.rs b/src/mempool/storage.rs index 503adc230..62a732aa4 100644 --- a/src/mempool/storage.rs +++ b/src/mempool/storage.rs @@ -43,6 +43,15 @@ pub struct LaneEntryMetadata { pub signatures: Vec, } +impl LaneEntryMetadata { + pub fn validators(&self) -> Vec { + self.signatures + .iter() + .map(|s| s.signature.validator.clone()) + .collect() + } +} + pub trait Storage { fn persist(&self) -> Result<()>; diff --git a/src/mempool/tests/mod.rs b/src/mempool/tests/mod.rs index 926c492c4..abf87b42b 100644 --- a/src/mempool/tests/mod.rs +++ b/src/mempool/tests/mod.rs @@ -169,36 +169,6 @@ impl MempoolTestCtx { .expect("fail to handle event"); } - #[track_caller] - pub fn assert_broadcast_only_for( - &mut self, - description: &str, - ) -> MsgWithHeader { - #[allow(clippy::expect_fun_call)] - let rec = self - .out_receiver - .try_recv() - .expect(format!("{description}: No message broadcasted").as_str()); - - match rec { - OutboundMessage::BroadcastMessageOnlyFor(_, net_msg) => { - if let NetMessage::MempoolMessage(msg) = net_msg { - msg - } else { - println!( - "{description}: Mempool OutboundMessage message is missing, found {net_msg}" - ); - self.assert_broadcast_only_for(description) - } - } - _ => { - println!( - "{description}: Broadcast OutboundMessage message is missing, found {rec:?}", - ); - self.assert_broadcast_only_for(description) - } - } - } pub fn assert_send( &mut self, to: &ValidatorPublicKey, @@ -282,6 +252,44 @@ impl MempoolTestCtx { }) } + pub fn assert_broadcast_only_for( + &mut self, + description: &str, + ) -> Pin< + Box< + dyn Future< + Output = ( + HashSet, + MsgWithHeader, + ), + > + '_, + >, + > { + let description = description.to_string().clone(); + Box::pin(async move { + #[allow(clippy::expect_fun_call)] + let rec = tokio::time::timeout(Duration::from_millis(1000), self.out_receiver.recv()) + .await + .expect(format!("{description}: No message broadcasted only for").as_str()) + .expect(format!("{description}: No message broadcasted only for").as_str()); + + match rec { + OutboundMessage::BroadcastMessageOnlyFor(validators, net_msg) => { + if let NetMessage::MempoolMessage(msg) = net_msg { + (validators, msg) + } else { + println!("{description}: Mempool OutboundMessage message is missing, found {net_msg}"); + self.assert_broadcast_only_for(description.as_str()).await + } + } + _ => { + println!("{description}: BroadcastOnlyFor OutboundMessage message is missing, found {rec:?}"); + self.assert_broadcast_only_for(description.as_str()).await + } + } + }) + } + pub async fn handle_msg(&mut self, msg: &MsgWithHeader, _err: &str) { debug!("📥 {} Handling message: {:?}", self.name, msg); self.mempool diff --git a/src/tests/autobahn_testing.rs b/src/tests/autobahn_testing.rs index e3eab5c6c..bfa02d7a4 100644 --- a/src/tests/autobahn_testing.rs +++ b/src/tests/autobahn_testing.rs @@ -52,6 +52,31 @@ macro_rules! broadcast { }; } +macro_rules! broadcast_only_for { + (description: $description:literal, from: $sender:expr, to: [$($node:expr),*]$(, message_matches: $pattern:pat $(=> $asserts:block)? )?) => { + { + // Construct the broadcast message with sender information + let (set, message) = $sender.assert_broadcast_only_for(format!("[broadcast from: {}] {}", stringify!($sender), $description).as_str()).await; + + $({ + let msg_variant_name: &'static str = message.msg.clone().into(); + if let $pattern = (&set, &message.msg) { + $($asserts)? + } else { + panic!("[broadcast only for from: {}] {}: Message {} did not match {}", stringify!($sender), $description, msg_variant_name, stringify!($pattern)); + } + })? + + // Distribute the message to each specified node + $( + $node.handle_msg(&message, (format!("[handling broadcast message from: {} at: {}] {}", stringify!($sender), stringify!($node), $description).as_str())).await; + )* + + message + } + }; +} + macro_rules! send { ( description: $description:literal, @@ -180,10 +205,12 @@ macro_rules! disseminate { .unwrap(); $owner.timer_tick().await.unwrap(); - let dp_msg = broadcast! { + let dp_msg = broadcast_only_for! { description: "Disseminate DataProposal", - from: $owner, to: [$($voter),+], - message_matches: MempoolNetMessage::DataProposal(_, _) + from: $owner, to: [$(&mut $voter),+], + message_matches: (set, MempoolNetMessage::DataProposal(_, _)) => { + assert_eq!(set.len(), vec![$(&$voter),+].len().div_euclid(3)); + } }; join_all( @@ -411,10 +438,11 @@ async fn autobahn_basic_flow() { .unwrap(); node1.mempool_ctx.timer_tick().await.unwrap(); - broadcast! { + broadcast_only_for! { description: "Disseminate Tx", from: node1.mempool_ctx, to: [node2.mempool_ctx, node3.mempool_ctx, node4.mempool_ctx], - message_matches: MempoolNetMessage::DataProposal(_, data) => { + message_matches: (set, MempoolNetMessage::DataProposal(_, data)) => { + assert_eq!(set.len(), 1); assert_eq!(data.txs.len(), 2); } }; @@ -547,10 +575,10 @@ async fn mempool_broadcast_multiple_data_proposals() { .unwrap(); node1.mempool_ctx.timer_tick().await.unwrap(); - broadcast! { + broadcast_only_for! { description: "Disseminate Tx", from: node1.mempool_ctx, to: [node2.mempool_ctx, node3.mempool_ctx, node4.mempool_ctx], - message_matches: MempoolNetMessage::DataProposal(_, _) + message_matches: (_set, MempoolNetMessage::DataProposal(_, _)) }; join_all( @@ -588,10 +616,10 @@ async fn mempool_broadcast_multiple_data_proposals() { .unwrap(); node1.mempool_ctx.timer_tick().await.unwrap(); - broadcast! { + broadcast_only_for! { description: "Disseminate Tx", from: node1.mempool_ctx, to: [node2.mempool_ctx, node3.mempool_ctx, node4.mempool_ctx], - message_matches: MempoolNetMessage::DataProposal(_, _) + message_matches: (_set, MempoolNetMessage::DataProposal(_, _)) }; join_all( @@ -628,10 +656,10 @@ async fn mempool_podaupdate_too_early() { .unwrap(); node1.mempool_ctx.timer_tick().await.unwrap(); - let dp_msg = broadcast! { + let dp_msg = broadcast_only_for! { description: "Disseminate Tx", from: node1.mempool_ctx, to: [node2.mempool_ctx, node3.mempool_ctx], - message_matches: MempoolNetMessage::DataProposal(_, _) + message_matches: (_set, MempoolNetMessage::DataProposal(_, _)) }; join_all( @@ -707,7 +735,7 @@ async fn mempool_podaupdate_too_early() { }; broadcast! { - description: "Disseminate Tx", + description: "Disseminate Poda", from: node1.mempool_ctx, to: [node2.mempool_ctx, node3.mempool_ctx, node4.mempool_ctx], message_matches: MempoolNetMessage::PoDAUpdate(hash, signatures) => { assert_eq!(hash, &dp.hashed()); @@ -897,10 +925,11 @@ async fn mempool_fail_to_vote_on_fork() { let dp1_check; - broadcast! { + broadcast_only_for! { description: "Disseminate Tx", from: node1.mempool_ctx, to: [node2.mempool_ctx, node3.mempool_ctx, node4.mempool_ctx], - message_matches: MempoolNetMessage::DataProposal(_, data) => { + message_matches: (_set, MempoolNetMessage::DataProposal(_, data)) => { + assert_eq!(_set.len(), 1); dp1_check = data.clone(); } }; @@ -925,8 +954,8 @@ async fn mempool_fail_to_vote_on_fork() { }; node1.mempool_ctx.assert_broadcast("poda update f+1").await; - node1.mempool_ctx.assert_broadcast("poda update 2f+1").await; - node1.mempool_ctx.assert_broadcast("poda update 3f+1").await; + // node1.mempool_ctx.assert_broadcast("poda update 2f+1").await; + // node1.mempool_ctx.assert_broadcast("poda update 3f+1").await; // Second data proposal @@ -942,10 +971,10 @@ async fn mempool_fail_to_vote_on_fork() { .unwrap(); node1.mempool_ctx.timer_tick().await.unwrap(); - broadcast! { + broadcast_only_for! { description: "Disseminate Tx", from: node1.mempool_ctx, to: [node2.mempool_ctx, node3.mempool_ctx, node4.mempool_ctx], - message_matches: MempoolNetMessage::DataProposal(_, _) + message_matches: (_set, MempoolNetMessage::DataProposal(_, _)) }; join_all( @@ -2318,10 +2347,10 @@ async fn follower_commits_cut_then_mempool_sends_stale_lane() { node1.mempool_ctx.timer_tick().await.unwrap(); // Disseminate to node2 - broadcast! { + broadcast_only_for! { description: "Disseminate Tx", from: node1.mempool_ctx, to: [node2.mempool_ctx], - message_matches: MempoolNetMessage::DataProposal(_, _) + message_matches: (_set, MempoolNetMessage::DataProposal(_, _)) }; node2.mempool_ctx.handle_processed_data_proposals().await; send! {