From dcabfdba7dfe94bf7d69bd929bf1ab9eaee99102 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 29 Sep 2024 23:25:53 +0530 Subject: [PATCH 01/30] feat: acknowledge notification --- rumqttc/src/client.rs | 203 +++++++++++++++++++++++++-------------- rumqttc/src/eventloop.rs | 22 ++--- rumqttc/src/lib.rs | 20 ++++ rumqttc/src/state.rs | 101 ++++++++++++------- 4 files changed, 225 insertions(+), 121 deletions(-) diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index a0c907049..47b9da0c8 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -3,7 +3,10 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; -use crate::{valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::{ + valid_filter, valid_topic, AckPromise, ConnectionError, Event, EventLoop, MqttOptions, + PromiseTx, Request, +}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -20,15 +23,15 @@ pub enum ClientError { TryRequest(Request), } -impl From> for ClientError { - fn from(e: SendError) -> Self { - Self::Request(e.into_inner()) +impl From)>> for ClientError { + fn from(e: SendError<(Request, Option)>) -> Self { + Self::Request(e.into_inner().0) } } -impl From> for ClientError { - fn from(e: TrySendError) -> Self { - Self::TryRequest(e.into_inner()) +impl From)>> for ClientError { + fn from(e: TrySendError<(Request, Option)>) -> Self { + Self::TryRequest(e.into_inner().0) } } @@ -41,7 +44,7 @@ impl From> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender, + request_tx: Sender<(Request, Option)>, } impl AsyncClient { @@ -61,7 +64,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender) -> AsyncClient { + pub fn from_senders(request_tx: Sender<(Request, Option)>) -> AsyncClient { AsyncClient { request_tx } } @@ -72,11 +75,12 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { + let (promise_tx, promise) = PromiseTx::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; @@ -84,8 +88,11 @@ impl AsyncClient { if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.request_tx.send_async(publish).await?; - Ok(()) + self.request_tx + .send_async((publish, Some(promise_tx))) + .await?; + + Ok(promise) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -95,11 +102,12 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { + let (promise_tx, promise) = PromiseTx::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; @@ -107,8 +115,9 @@ impl AsyncClient { if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } - self.request_tx.try_send(publish)?; - Ok(()) + self.request_tx.try_send((publish, Some(promise_tx)))?; + + Ok(promise) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -116,8 +125,9 @@ impl AsyncClient { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.send_async(ack).await?; + self.request_tx.send_async((ack, None)).await?; } + Ok(()) } @@ -125,8 +135,9 @@ impl AsyncClient { pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.try_send(ack)?; + self.request_tx.try_send((ack, None))?; } + Ok(()) } @@ -137,93 +148,123 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { + let (promise_tx, promise) = PromiseTx::new(); let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; let publish = Request::Publish(publish); - self.request_tx.send_async(publish).await?; - Ok(()) + self.request_tx + .send_async((publish, Some(promise_tx))) + .await?; + + Ok(promise) } /// Sends a MQTT Subscribe to the `EventLoop` - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub async fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new(topic, qos); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::Request(subscribe.into())); } + self.request_tx + .send_async((subscribe.into(), Some(promise_tx))) + .await?; - self.request_tx.send_async(subscribe.into()).await?; - Ok(()) + Ok(promise) } /// Attempts to send a MQTT Subscribe to the `EventLoop` - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new(topic, qos); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::TryRequest(subscribe.into())); } + self.request_tx + .try_send((subscribe.into(), Some(promise_tx)))?; - self.request_tx.try_send(subscribe.into())?; - Ok(()) + Ok(promise) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new_many(topics); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::Request(subscribe.into())); } + self.request_tx + .send_async((subscribe.into(), Some(promise_tx))) + .await?; - self.request_tx.send_async(subscribe.into()).await?; - Ok(()) + Ok(promise) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new_many(topics); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::TryRequest(subscribe.into())); } - self.request_tx.try_send(subscribe.into())?; - Ok(()) + self.request_tx + .try_send((subscribe.into(), Some(promise_tx)))?; + + Ok(promise) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub async fn unsubscribe>(&self, topic: S) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); - self.request_tx.send_async(request).await?; - Ok(()) + self.request_tx + .send_async((unsubscribe.into(), Some(promise_tx))) + .await?; + + Ok(promise) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&self, topic: S) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send(request)?; - Ok(()) + self.request_tx + .try_send((unsubscribe.into(), Some(promise_tx)))?; + + Ok(promise) } /// Sends a MQTT disconnect to the `EventLoop` pub async fn disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect(Disconnect); - self.request_tx.send_async(request).await?; + self.request_tx.send_async((request, None)).await?; + Ok(()) } /// Attempts to send a MQTT disconnect to the `EventLoop` pub fn try_disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect(Disconnect); - self.request_tx.try_send(request)?; + self.request_tx.try_send((request, None))?; + Ok(()) } } @@ -272,7 +313,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender) -> Client { + pub fn from_sender(request_tx: Sender<(Request, Option)>) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -285,11 +326,12 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { + let (promise_tx, promise) = PromiseTx::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; @@ -297,8 +339,9 @@ impl Client { if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.client.request_tx.send(publish)?; - Ok(()) + self.client.request_tx.send((publish, Some(promise_tx)))?; + + Ok(promise) } pub fn try_publish( @@ -307,13 +350,12 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result where S: Into, V: Into>, { - self.client.try_publish(topic, qos, retain, payload)?; - Ok(()) + self.client.try_publish(topic, qos, retain, payload) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. @@ -321,49 +363,62 @@ impl Client { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.client.request_tx.send(ack)?; + self.client.request_tx.send((ack, None))?; } + Ok(()) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - self.client.try_ack(publish)?; - Ok(()) + self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the `EventLoop` - pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new(topic, qos); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::Request(subscribe.into())); } + self.client + .request_tx + .send((subscribe.into(), Some(promise_tx)))?; - self.client.request_tx.send(subscribe.into())?; - Ok(()) + Ok(promise) } /// Sends a MQTT Subscribe to the `EventLoop` - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - self.client.try_subscribe(topic, qos)?; - Ok(()) + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { + self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new_many(topics); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::Request(subscribe.into())); } + self.client + .request_tx + .send((subscribe.into(), Some(promise_tx)))?; - self.client.request_tx.send(subscribe.into())?; - Ok(()) + Ok(promise) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -371,30 +426,32 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send(request)?; - Ok(()) + self.client.request_tx.send((request, Some(promise_tx)))?; + + Ok(promise) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - self.client.try_unsubscribe(topic)?; - Ok(()) + pub fn try_unsubscribe>(&self, topic: S) -> Result { + self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result<(), ClientError> { + pub fn disconnect(&self) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let request = Request::Disconnect(Disconnect); - self.client.request_tx.send(request)?; - Ok(()) + self.client.request_tx.send((request, Some(promise_tx)))?; + + Ok(promise) } /// Sends a MQTT disconnect to the `EventLoop` pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.client.try_disconnect()?; - Ok(()) + self.client.try_disconnect() } } diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index d31690d99..b98a390b2 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,5 +1,5 @@ use crate::{framed::Network, Transport}; -use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError}; +use crate::{Incoming, MqttState, NetworkOptions, Packet, PromiseTx, Request, StateError}; use crate::{MqttOptions, Outgoing}; use crate::framed::AsyncReadWrite; @@ -75,11 +75,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver, + requests_rx: Receiver<(Request, Option)>, /// Requests handle to send requests - pub(crate) requests_tx: Sender, + pub(crate) requests_tx: Sender<(Request, Option)>, /// Pending packets from last session - pub pending: VecDeque, + pub pending: VecDeque<(Request, Option)>, /// Network connection to the broker pub network: Option, /// Keep alive time @@ -132,7 +132,7 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|request| { + requests_in_channel.retain(|(request, _)| { match request { Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, @@ -241,8 +241,8 @@ impl EventLoop { &self.requests_rx, self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok(request) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + Ok((request, tx)) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(request, tx)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -260,7 +260,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq), None)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -282,10 +282,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque, - rx: &Receiver, + pending: &mut VecDeque<(Request, Option)>, + rx: &Receiver<(Request, Option)>, pending_throttle: Duration, - ) -> Result { + ) -> Result<(Request, Option), ConnectionError> { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .pop_front() AFTER sleep() otherwise we would have diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 07694ffaf..a38915323 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -145,6 +145,7 @@ use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] pub use tls::Error as TlsError; +use tokio::sync::oneshot; #[cfg(feature = "use-native-tls")] pub use tokio_native_tls; #[cfg(feature = "use-native-tls")] @@ -222,6 +223,25 @@ impl From for Request { } } +pub type AckPromise = oneshot::Receiver<()>; + +#[derive(Debug)] +pub struct PromiseTx { + inner: oneshot::Sender<()>, +} + +impl PromiseTx { + fn new() -> (PromiseTx, AckPromise) { + let (inner, promise) = oneshot::channel(); + + (PromiseTx { inner }, promise) + } + + fn resolve(self) { + self.inner.send(()).unwrap() + } +} + /// Transport methods. Defaults to TCP. #[derive(Clone)] pub enum Transport { diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index f7cb34841..5bfc2cc89 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,4 +1,4 @@ -use crate::{Event, Incoming, Outgoing, Request}; +use crate::{Event, Incoming, Outgoing, PromiseTx, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; @@ -40,7 +40,7 @@ pub enum StateError { // This is done for 2 reasons // Bad acks or out of order acks aren't O(n) causing cpu spikes // Any missing acks from the broker are detected during the next recycled use of packet ids -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct MqttState { /// Status of last ping pub await_pingresp: bool, @@ -67,11 +67,13 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option, + pub collision: Option<(Publish, Option)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, + /// Waiters for publish/subscribe/unsubscribe acknowledgements + pub ack_waiter: Vec>, } impl MqttState { @@ -96,11 +98,12 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), manual_acks, + ack_waiter: (0..max_inflight + 1).map(|_| None).collect(), } } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec { + pub fn clean(&mut self) -> Vec<(Request, Option)> { let mut pending = Vec::with_capacity(100); let (first_half, second_half) = self .outgoing_pub @@ -108,15 +111,17 @@ impl MqttState { for publish in second_half.iter_mut().chain(first_half) { if let Some(publish) = publish.take() { + let tx = self.ack_waiter.remove(publish.pkid as usize); let request = Request::Publish(publish); - pending.push(request); + pending.push((request, tx)); } } // remove and collect pending releases for pkid in self.outgoing_rel.ones() { + let tx = self.ack_waiter.remove(pkid); let request = Request::PubRel(PubRel::new(pkid as u16)); - pending.push(request); + pending.push((request, tx)); } self.outgoing_rel.clear(); @@ -138,12 +143,13 @@ impl MqttState { pub fn handle_outgoing_packet( &mut self, request: Request, + tx: Option, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::Publish(publish) => self.outgoing_publish(publish, tx)?, Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq(_) => self.outgoing_ping()?, Request::Disconnect(_) => self.outgoing_disconnect()?, Request::PubAck(puback) => self.outgoing_puback(puback)?, @@ -220,26 +226,32 @@ impl MqttState { } fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self + let p = self .outgoing_pub .get_mut(puback.pkid as usize) .ok_or(StateError::Unsolicited(puback.pkid))?; self.last_puback = puback.pkid; - if publish.take().is_none() { + if p.take().is_none() { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); } + if let Some(tx) = self.ack_waiter.remove(puback.pkid as usize) { + // Resolve promise for QoS 1 + tx.resolve(); + } + self.inflight -= 1; - let packet = self.check_collision(puback.pkid).map(|publish| { + let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; + self.ack_waiter[puback.pkid as usize] = tx; Packet::Publish(publish) }); @@ -248,12 +260,13 @@ impl MqttState { } fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { - let publish = self + if self .outgoing_pub .get_mut(pubrec.pkid as usize) - .ok_or(StateError::Unsolicited(pubrec.pkid))?; - - if publish.take().is_none() { + .ok_or(StateError::Unsolicited(pubrec.pkid))? + .take() + .is_none() + { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); return Err(StateError::Unsolicited(pubrec.pkid)); } @@ -287,12 +300,17 @@ impl MqttState { return Err(StateError::Unsolicited(pubcomp.pkid)); } - self.outgoing_rel.set(pubcomp.pkid as usize, false); + if let Some(tx) = self.ack_waiter.remove(pubcomp.pkid as usize) { + // Resolve promise for QoS 2 + tx.resolve(); + } + self.inflight -= 1; - let packet = self.check_collision(pubcomp.pkid).map(|publish| { + let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; + self.ack_waiter[pubcomp.pkid as usize] = tx; Packet::Publish(publish) }); @@ -308,7 +326,11 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + tx: Option, + ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -322,7 +344,7 @@ impl MqttState { .is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some(publish); + self.collision = Some((publish, tx)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -343,6 +365,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); + self.ack_waiter[publish.pkid as usize] = tx; Ok(Some(Packet::Publish(publish))) } @@ -409,6 +432,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, + tx: Option, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -424,6 +448,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); + self.ack_waiter[subscription.pkid as usize] = tx; Ok(Some(Packet::Subscribe(subscription))) } @@ -431,6 +456,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, + tx: Option, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -442,6 +468,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); + self.ack_waiter[unsub.pkid as usize] = tx; Ok(Some(Packet::Unsubscribe(unsub))) } @@ -455,8 +482,8 @@ impl MqttState { Ok(Some(Packet::Disconnect)) } - fn check_collision(&mut self, pkid: u16) -> Option { - if let Some(publish) = &self.collision { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); } @@ -555,7 +582,7 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -563,12 +590,12 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -576,12 +603,12 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -667,8 +694,8 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); + mqtt.outgoing_publish(publish1, None).unwrap(); + mqtt.outgoing_publish(publish2, None).unwrap(); assert_eq!(mqtt.inflight, 2); mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); @@ -700,8 +727,8 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); + let _publish_out = mqtt.outgoing_publish(publish1, None); + let _publish_out = mqtt.outgoing_publish(publish2, None); mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); assert_eq!(mqtt.inflight, 2); @@ -719,7 +746,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - let packet = mqtt.outgoing_publish(publish).unwrap().unwrap(); + let packet = mqtt.outgoing_publish(publish, None).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -761,7 +788,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); @@ -775,7 +802,7 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) + mqtt.handle_outgoing_packet(Request::Publish(publish), None) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) .unwrap(); @@ -849,7 +876,7 @@ mod test { let requests = mqtt.clean(); let res = vec![6, 1, 2, 3]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(publish) = &req.0 { assert_eq!(publish.pkid, idx); } else { unreachable!() @@ -861,7 +888,7 @@ mod test { let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(publish) = &req.0 { assert_eq!(publish.pkid, idx); } else { unreachable!() @@ -873,7 +900,7 @@ mod test { let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(publish) = &req.0 { assert_eq!(publish.pkid, idx); } else { unreachable!() From bc8de881446577123b992f8aa548e1bd41166591 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 29 Sep 2024 23:36:25 +0530 Subject: [PATCH 02/30] fix: don't panic if promise dropped --- rumqttc/src/lib.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index a38915323..30081214d 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -238,7 +238,9 @@ impl PromiseTx { } fn resolve(self) { - self.inner.send(()).unwrap() + if self.inner.send(()).is_err() { + trace!("Promise was drpped") + } } } From 176fe5c37ede8f9a687f66b1c905aaf71eb933cf Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 29 Sep 2024 23:44:51 +0530 Subject: [PATCH 03/30] fix: don't remove slots, just take contents --- rumqttc/src/state.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 5bfc2cc89..4e5953730 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -111,7 +111,7 @@ impl MqttState { for publish in second_half.iter_mut().chain(first_half) { if let Some(publish) = publish.take() { - let tx = self.ack_waiter.remove(publish.pkid as usize); + let tx = self.ack_waiter[publish.pkid as usize].take(); let request = Request::Publish(publish); pending.push((request, tx)); } @@ -119,7 +119,7 @@ impl MqttState { // remove and collect pending releases for pkid in self.outgoing_rel.ones() { - let tx = self.ack_waiter.remove(pkid); + let tx = self.ack_waiter[pkid].take(); let request = Request::PubRel(PubRel::new(pkid as u16)); pending.push((request, tx)); } @@ -238,7 +238,7 @@ impl MqttState { return Err(StateError::Unsolicited(puback.pkid)); } - if let Some(tx) = self.ack_waiter.remove(puback.pkid as usize) { + if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { // Resolve promise for QoS 1 tx.resolve(); } @@ -300,7 +300,7 @@ impl MqttState { return Err(StateError::Unsolicited(pubcomp.pkid)); } - if let Some(tx) = self.ack_waiter.remove(pubcomp.pkid as usize) { + if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { // Resolve promise for QoS 2 tx.resolve(); } From c4ce2f74c41d83726e9ccbe00c77403e2e2168ac Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 29 Sep 2024 23:55:01 +0530 Subject: [PATCH 04/30] fix: direct resolve for QoS 0 --- rumqttc/src/state.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 4e5953730..7c24f46ca 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -365,7 +365,10 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); - self.ack_waiter[publish.pkid as usize] = tx; + match (publish.qos, tx) { + (QoS::AtMostOnce, Some(tx)) => tx.resolve(), + (_, tx) => self.ack_waiter[publish.pkid as usize] = tx, + } Ok(Some(Packet::Publish(publish))) } From b6c5ed3cee2e8c2d7d72b891195f92f360d55def Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 01:21:19 +0530 Subject: [PATCH 05/30] fix: validate and notify sub/unsub acks --- rumqttc/src/state.rs | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 7c24f46ca..2193a3035 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -174,8 +174,8 @@ impl MqttState { let outgoing = match &packet { Incoming::PingResp => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, - Incoming::SubAck(_suback) => self.handle_incoming_suback()?, - Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback()?, + Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, + Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback)?, Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, @@ -190,11 +190,44 @@ impl MqttState { Ok(outgoing) } - fn handle_incoming_suback(&mut self) -> Result, StateError> { + fn is_pkid_of_publish(&self, pkid: u16) -> bool { + self.outgoing_pub[pkid as usize].is_some() || self.outgoing_rel.contains(pkid as usize) + } + + fn handle_incoming_suback( + &mut self, + SubAck { pkid, return_codes }: &SubAck, + ) -> Result, StateError> { + // Expected ack for a subscribe packet, not a publish packet + if self.is_pkid_of_publish(*pkid) { + return Err(StateError::Unsolicited(*pkid)); + } + + if return_codes + .iter() + .any(|x| matches!(x, SubscribeReasonCode::Success(_))) + { + if let Some(tx) = self.ack_waiter[*pkid as usize].take() { + tx.resolve(); + } + } + Ok(None) } - fn handle_incoming_unsuback(&mut self) -> Result, StateError> { + fn handle_incoming_unsuback( + &mut self, + UnsubAck { pkid }: &UnsubAck, + ) -> Result, StateError> { + // Expected ack for a unsubscribe packet, not a publish packet + if self.is_pkid_of_publish(*pkid) { + return Err(StateError::Unsolicited(*pkid)); + } + + if let Some(tx) = self.ack_waiter[*pkid as usize].take() { + tx.resolve(); + } + Ok(None) } From ab96189d5fc1aa5539966e566731ff259c797e24 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 01:23:01 +0530 Subject: [PATCH 06/30] fix: forget acked packet --- rumqttc/src/state.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 2193a3035..3ea124e35 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -338,6 +338,7 @@ impl MqttState { tx.resolve(); } + self.outgoing_rel.set(pubcomp.pkid as usize, false); self.inflight -= 1; let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); From 656cf742d24d0c22ac7d68b56e6ad2b7f0bedaa8 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 01:33:17 +0530 Subject: [PATCH 07/30] refactor --- rumqttc/src/state.rs | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 3ea124e35..a7bc16fa7 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -194,20 +194,18 @@ impl MqttState { self.outgoing_pub[pkid as usize].is_some() || self.outgoing_rel.contains(pkid as usize) } - fn handle_incoming_suback( - &mut self, - SubAck { pkid, return_codes }: &SubAck, - ) -> Result, StateError> { + fn handle_incoming_suback(&mut self, suback: &SubAck) -> Result, StateError> { // Expected ack for a subscribe packet, not a publish packet - if self.is_pkid_of_publish(*pkid) { - return Err(StateError::Unsolicited(*pkid)); + if self.is_pkid_of_publish(suback.pkid) { + return Err(StateError::Unsolicited(suback.pkid)); } - if return_codes + if suback + .return_codes .iter() .any(|x| matches!(x, SubscribeReasonCode::Success(_))) { - if let Some(tx) = self.ack_waiter[*pkid as usize].take() { + if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { tx.resolve(); } } @@ -217,14 +215,14 @@ impl MqttState { fn handle_incoming_unsuback( &mut self, - UnsubAck { pkid }: &UnsubAck, + unsuback: &UnsubAck, ) -> Result, StateError> { // Expected ack for a unsubscribe packet, not a publish packet - if self.is_pkid_of_publish(*pkid) { - return Err(StateError::Unsolicited(*pkid)); + if self.is_pkid_of_publish(unsuback.pkid) { + return Err(StateError::Unsolicited(unsuback.pkid)); } - if let Some(tx) = self.ack_waiter[*pkid as usize].take() { + if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { tx.resolve(); } From 4a936fedde6c360fa65124f3e27e129d8d9532aa Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 01:36:06 +0530 Subject: [PATCH 08/30] fix: panic if `max_inflight == u16::MAX` --- rumqttc/src/state.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index a7bc16fa7..81a60c904 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -98,7 +98,7 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), manual_acks, - ack_waiter: (0..max_inflight + 1).map(|_| None).collect(), + ack_waiter: (0..max_inflight as usize + 1).map(|_| None).collect(), } } From d50859506b4b97041a69d33d8e43757a658ba295 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 01:36:45 +0530 Subject: [PATCH 09/30] feat: notify acks in v5 also --- rumqttc/src/v5/client.rs | 253 ++++++++++++++++++++++-------------- rumqttc/src/v5/eventloop.rs | 21 +-- rumqttc/src/v5/state.rs | 157 ++++++++++++++++------ 3 files changed, 278 insertions(+), 153 deletions(-) diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 4913d1d0f..a099f54d7 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -8,7 +8,7 @@ use super::mqttbytes::v5::{ }; use super::mqttbytes::QoS; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; -use crate::{valid_filter, valid_topic}; +use crate::{valid_filter, valid_topic, AckPromise, PromiseTx}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -25,19 +25,19 @@ pub enum ClientError { TryRequest(Request), } -impl From> for ClientError { - fn from(e: SendError) -> Self { - Self::Request(e.into_inner()) +impl From)>> for ClientError { + fn from(e: SendError<(Request, Option)>) -> Self { + Self::Request(e.into_inner().0) } } -impl From> for ClientError { - fn from(e: TrySendError) -> Self { - Self::TryRequest(e.into_inner()) +impl From)>> for ClientError { + fn from(e: TrySendError<(Request, Option)>) -> Self { + Self::TryRequest(e.into_inner().0) } } -/// An asynchronous client, communicates with MQTT `EventLoop`. +// An asynchronous client, communicates with MQTT `EventLoop`. /// /// This is cloneable and can be used to asynchronously [`publish`](`AsyncClient::publish`), /// [`subscribe`](`AsyncClient::subscribe`) through the `EventLoop`, which is to be polled parallelly. @@ -46,7 +46,7 @@ impl From> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender, + request_tx: Sender<(Request, Option)>, } impl AsyncClient { @@ -66,7 +66,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender) -> AsyncClient { + pub fn from_senders(request_tx: Sender<(Request, Option)>) -> AsyncClient { AsyncClient { request_tx } } @@ -78,11 +78,12 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, { + let (promise_tx, promise) = PromiseTx::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; @@ -90,8 +91,11 @@ impl AsyncClient { if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.request_tx.send_async(publish).await?; - Ok(()) + self.request_tx + .send_async((publish, Some(promise_tx))) + .await?; + + Ok(promise) } pub async fn publish_with_properties( @@ -101,7 +105,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -116,7 +120,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -132,11 +136,12 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, { + let (promise_tx, promise) = PromiseTx::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; @@ -144,8 +149,9 @@ impl AsyncClient { if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } - self.request_tx.try_send(publish)?; - Ok(()) + self.request_tx.try_send((publish, Some(promise_tx)))?; + + Ok(promise) } pub fn try_publish_with_properties( @@ -155,7 +161,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -169,7 +175,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -182,8 +188,9 @@ impl AsyncClient { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.send_async(ack).await?; + self.request_tx.send_async((ack, None)).await?; } + Ok(()) } @@ -191,8 +198,9 @@ impl AsyncClient { pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.request_tx.try_send(ack)?; + self.request_tx.try_send((ack, None))?; } + Ok(()) } @@ -204,19 +212,20 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { + let (promise_tx, promise) = PromiseTx::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; let publish = Request::Publish(publish); - if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); - } - self.request_tx.send_async(publish).await?; - Ok(()) + self.request_tx + .send_async((publish, Some(promise_tx))) + .await?; + + Ok(promise) } pub async fn publish_bytes_with_properties( @@ -226,7 +235,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { @@ -240,7 +249,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result where S: Into, { @@ -254,15 +263,18 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::Request(subscribe.into())); } + self.request_tx + .send_async((subscribe.into(), Some(promise_tx))) + .await?; - self.request_tx.send_async(subscribe.into()).await?; - Ok(()) + Ok(promise) } pub async fn subscribe_with_properties>( @@ -270,11 +282,15 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_subscribe(topic, qos, Some(properties)).await } - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub async fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_subscribe(topic, qos, None).await } @@ -284,15 +300,17 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::TryRequest(subscribe.into())); } + self.request_tx + .try_send((subscribe.into(), Some(promise_tx)))?; - self.request_tx.try_send(subscribe.into())?; - Ok(()) + Ok(promise) } pub fn try_subscribe_with_properties>( @@ -300,11 +318,15 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_try_subscribe(topic, qos, Some(properties)) } - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_try_subscribe(topic, qos, None) } @@ -313,32 +335,34 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new_many(topics, properties); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::Request(subscribe.into())); } + self.request_tx + .send_async((subscribe.into(), Some(promise_tx))) + .await?; - self.request_tx.send_async(subscribe.into()).await?; - - Ok(()) + Ok(promise) } pub async fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -350,31 +374,33 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new_many(topics, properties); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::TryRequest(subscribe.into())); } + self.request_tx + .try_send((subscribe.into(), Some(promise_tx)))?; - self.request_tx.try_send(subscribe.into())?; - Ok(()) + Ok(promise) } pub fn try_subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -386,22 +412,26 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.send_async(request).await?; - Ok(()) + self.request_tx + .send_async((request, Some(promise_tx))) + .await?; + + Ok(promise) } pub async fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_unsubscribe(topic, Some(properties)).await } - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub async fn unsubscribe>(&self, topic: S) -> Result { self.handle_unsubscribe(topic, None).await } @@ -410,36 +440,40 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send(request)?; - Ok(()) + self.request_tx.try_send((request, Some(promise_tx)))?; + + Ok(promise) } pub fn try_unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_try_unsubscribe(topic, Some(properties)) } - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.handle_try_unsubscribe(topic, None) } /// Sends a MQTT disconnect to the `EventLoop` pub async fn disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.request_tx.send_async(request).await?; + self.request_tx.send_async((request, None)).await?; + Ok(()) } /// Attempts to send a MQTT disconnect to the `EventLoop` pub fn try_disconnect(&self) -> Result<(), ClientError> { let request = Request::Disconnect; - self.request_tx.try_send(request)?; + self.request_tx.try_send((request, None))?; + Ok(()) } } @@ -489,7 +523,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender) -> Client { + pub fn from_sender(request_tx: Sender<(Request, Option)>) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -503,11 +537,12 @@ impl Client { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, { + let (promise_tx, promise) = PromiseTx::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; @@ -515,8 +550,9 @@ impl Client { if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.client.request_tx.send(publish)?; - Ok(()) + self.client.request_tx.send((publish, Some(promise_tx)))?; + + Ok(promise) } pub fn publish_with_properties( @@ -526,7 +562,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -540,7 +576,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -555,7 +591,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -570,7 +606,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result where S: Into, P: Into, @@ -583,15 +619,15 @@ impl Client { let ack = get_ack_req(publish); if let Some(ack) = ack { - self.client.request_tx.send(ack)?; + self.client.request_tx.send((ack, None))?; } + Ok(()) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - self.client.try_ack(publish)?; - Ok(()) + self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -600,15 +636,18 @@ impl Client { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::Request(subscribe.into())); } + self.client + .request_tx + .send((subscribe.into(), Some(promise_tx)))?; - self.client.request_tx.send(subscribe.into())?; - Ok(()) + Ok(promise) } pub fn subscribe_with_properties>( @@ -616,11 +655,15 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_subscribe(topic, qos, Some(properties)) } - pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.handle_subscribe(topic, qos, None) } @@ -630,12 +673,16 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.client .try_subscribe_with_properties(topic, qos, properties) } - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result { self.client.try_subscribe(topic, qos) } @@ -644,31 +691,34 @@ impl Client { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { + let (promise_tx, promise) = PromiseTx::new(); let subscribe = Subscribe::new_many(topics, properties); if !subscribe_has_valid_filters(&subscribe) { return Err(ClientError::Request(subscribe.into())); } + self.client + .request_tx + .send((subscribe.into(), Some(promise_tx)))?; - self.client.request_tx.send(subscribe.into())?; - Ok(()) + Ok(promise) } pub fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } - pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -679,7 +729,7 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result where T: IntoIterator, { @@ -687,7 +737,7 @@ impl Client { .try_subscribe_many_with_properties(topics, properties) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result where T: IntoIterator, { @@ -699,22 +749,24 @@ impl Client { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send(request)?; - Ok(()) + self.client.request_tx.send((request, Some(promise_tx)))?; + + Ok(promise) } pub fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.handle_unsubscribe(topic, Some(properties)) } - pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result { self.handle_unsubscribe(topic, None) } @@ -723,26 +775,27 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result { self.client .try_unsubscribe_with_properties(topic, properties) } - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>(&self, topic: S) -> Result { self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result<(), ClientError> { + pub fn disconnect(&self) -> Result { + let (promise_tx, promise) = PromiseTx::new(); let request = Request::Disconnect; - self.client.request_tx.send(request)?; - Ok(()) + self.client.request_tx.send((request, Some(promise_tx)))?; + + Ok(promise) } /// Sends a MQTT disconnect to the `EventLoop` pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.client.try_disconnect()?; - Ok(()) + self.client.try_disconnect() } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index cd0568ada..b2c2fc506 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -3,6 +3,7 @@ use super::mqttbytes::v5::*; use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; use crate::eventloop::socket_connect; use crate::framed::AsyncReadWrite; +use crate::PromiseTx; use flume::{bounded, Receiver, Sender}; use tokio::select; @@ -73,11 +74,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver, + requests_rx: Receiver<(Request, Option)>, /// Requests handle to send requests - pub(crate) requests_tx: Sender, + pub(crate) requests_tx: Sender<(Request, Option)>, /// Pending packets from last session - pub pending: VecDeque, + pub pending: VecDeque<(Request, Option)>, /// Network connection to the broker network: Option, /// Keep alive time @@ -128,7 +129,7 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|request| { + requests_in_channel.retain(|(request, _)| { match request { Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, @@ -223,8 +224,8 @@ impl EventLoop { &self.requests_rx, self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok(request) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + Ok((request, tx)) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(request, tx)? { network.write(outgoing).await?; } network.flush().await?; @@ -245,7 +246,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq, None)? { network.write(outgoing).await?; } network.flush().await?; @@ -255,10 +256,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque, - rx: &Receiver, + pending: &mut VecDeque<(Request, Option)>, + rx: &Receiver<(Request, Option)>, pending_throttle: Duration, - ) -> Result { + ) -> Result<(Request, Option), ConnectionError> { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .next() AFTER sleep() otherwise .next() would diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 0f08a33b8..5d7ebdab3 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,3 +1,5 @@ +use crate::PromiseTx; + use super::mqttbytes::v5::{ ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, @@ -74,7 +76,7 @@ impl From for StateError { // This is done for 2 reasons // Bad acks or out of order acks aren't O(n) causing cpu spikes // Any missing acks from the broker are detected during the next recycled use of packet ids -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct MqttState { /// Status of last ping pub await_pingresp: bool, @@ -97,7 +99,7 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option, + pub collision: Option<(Publish, Option)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -110,6 +112,8 @@ pub struct MqttState { pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, + /// Waiters for publish/subscribe/unsubscribe acknowledgements + pub ack_waiter: Vec>, } impl MqttState { @@ -137,24 +141,27 @@ impl MqttState { broker_topic_alias_max: 0, max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, + ack_waiter: (0..max_inflight as usize + 1).map(|_| None).collect(), } } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec { + pub fn clean(&mut self) -> Vec<(Request, Option)> { let mut pending = Vec::with_capacity(100); // remove and collect pending publishes for publish in self.outgoing_pub.iter_mut() { if let Some(publish) = publish.take() { + let tx = self.ack_waiter[publish.pkid as usize].take(); let request = Request::Publish(publish); - pending.push(request); + pending.push((request, tx)); } } // remove and collect pending releases for pkid in self.outgoing_rel.ones() { + let tx = self.ack_waiter[pkid].take(); let request = Request::PubRel(PubRel::new(pkid as u16, None)); - pending.push(request); + pending.push((request, tx)); } self.outgoing_rel.clear(); @@ -176,12 +183,13 @@ impl MqttState { pub fn handle_outgoing_packet( &mut self, request: Request, + tx: Option, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::Publish(publish) => self.outgoing_publish(publish, tx)?, Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, Request::PingReq => self.outgoing_ping()?, Request::Disconnect => { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? @@ -231,10 +239,19 @@ impl MqttState { self.outgoing_disconnect(DisconnectReasonCode::ProtocolError) } + fn is_pkid_of_publish(&self, pkid: u16) -> bool { + self.outgoing_pub[pkid as usize].is_some() || self.outgoing_rel.contains(pkid as usize) + } + fn handle_incoming_suback( &mut self, suback: &mut SubAck, ) -> Result, StateError> { + // Expected ack for a subscribe packet, not a publish packet + if self.is_pkid_of_publish(suback.pkid) { + return Err(StateError::Unsolicited(suback.pkid)); + } + for reason in suback.return_codes.iter() { match reason { SubscribeReasonCode::Success(qos) => { @@ -242,9 +259,20 @@ impl MqttState { } _ => { warn!("SubAck Pkid = {:?}, Reason = {:?}", suback.pkid, reason); - }, + } + } + } + + if suback + .return_codes + .iter() + .any(|x| matches!(x, SubscribeReasonCode::Success(_))) + { + if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { + tx.resolve(); } } + Ok(None) } @@ -252,11 +280,23 @@ impl MqttState { &mut self, unsuback: &mut UnsubAck, ) -> Result, StateError> { + // Expected ack for a unsubscribe packet, not a publish packet + if self.is_pkid_of_publish(unsuback.pkid) { + return Err(StateError::Unsolicited(unsuback.pkid)); + } + for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { warn!("UnsubAck Pkid = {:?}, Reason = {:?}", unsuback.pkid, reason); } } + + if unsuback.reasons.contains(&UnsubAckReason::Success) { + if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { + tx.resolve(); + } + } + Ok(None) } @@ -359,16 +399,24 @@ impl MqttState { return Err(StateError::Unsolicited(puback.pkid)); } + if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { + // Resolve promise for QoS 1 + tx.resolve(); + } + self.inflight -= 1; if puback.reason != PubAckReason::Success && puback.reason != PubAckReason::NoMatchingSubscribers { - warn!("PubAck Pkid = {:?}, reason: {:?}", puback.pkid, puback.reason); + warn!( + "PubAck Pkid = {:?}, reason: {:?}", + puback.pkid, puback.reason + ); return Ok(None); } - if let Some(publish) = self.check_collision(puback.pkid) { + if let Some((publish, tx)) = self.check_collision(puback.pkid) { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; @@ -376,6 +424,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; + self.ack_waiter[puback.pkid as usize] = tx; return Ok(Some(Packet::Publish(publish))); } @@ -397,7 +446,10 @@ impl MqttState { if pubrec.reason != PubRecReason::Success && pubrec.reason != PubRecReason::NoMatchingSubscribers { - warn!("PubRec Pkid = {:?}, reason: {:?}", pubrec.pkid, pubrec.reason); + warn!( + "PubRec Pkid = {:?}, reason: {:?}", + pubrec.pkid, pubrec.reason + ); return Ok(None); } @@ -417,7 +469,10 @@ impl MqttState { self.incoming_pub.set(pubrel.pkid as usize, false); if pubrel.reason != PubRelReason::Success { - warn!("PubRel Pkid = {:?}, reason: {:?}", pubrel.pkid, pubrel.reason); + warn!( + "PubRel Pkid = {:?}, reason: {:?}", + pubrel.pkid, pubrel.reason + ); return Ok(None); } @@ -428,23 +483,31 @@ impl MqttState { } fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - let outgoing = self.check_collision(pubcomp.pkid).map(|publish| { - let pkid = publish.pkid; - let event = Event::Outgoing(Outgoing::Publish(pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - - Packet::Publish(publish) - }); - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } + + if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { + // Resolve promise for QoS 2 + tx.resolve(); + } + self.outgoing_rel.set(pubcomp.pkid as usize, false); + let outgoing = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.ack_waiter[pubcomp.pkid as usize] = tx; + + Packet::Publish(publish) + }); if pubcomp.reason != PubCompReason::Success { - warn!("PubComp Pkid = {:?}, reason: {:?}", pubcomp.pkid, pubcomp.reason); + warn!( + "PubComp Pkid = {:?}, reason: {:?}", + pubcomp.pkid, pubcomp.reason + ); return Ok(None); } @@ -459,7 +522,11 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + tx: Option, + ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -473,7 +540,7 @@ impl MqttState { .is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some(publish); + self.collision = Some((publish, tx)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -575,6 +642,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, + tx: Option, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -591,6 +659,7 @@ impl MqttState { let pkid = subscription.pkid; let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); + self.ack_waiter[subscription.pkid as usize] = tx; Ok(Some(Packet::Subscribe(subscription))) } @@ -598,6 +667,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, + tx: Option, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -610,6 +680,7 @@ impl MqttState { let pkid = unsub.pkid; let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); + self.ack_waiter[unsub.pkid as usize] = tx; Ok(Some(Packet::Unsubscribe(unsub))) } @@ -625,8 +696,8 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - fn check_collision(&mut self, pkid: u16) -> Option { - if let Some(publish) = &self.collision { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); } @@ -725,7 +796,7 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -733,12 +804,12 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -746,12 +817,12 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -763,17 +834,17 @@ mod test { // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be set back down to 0, since we hit the limit - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); // This should cause a collition - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 2); assert!(mqtt.collision.is_some()); @@ -783,7 +854,7 @@ mod test { assert_eq!(mqtt.inflight, 1); // Now there should be space in the outgoing queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + mqtt.outgoing_publish(publish.clone(), None).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); } @@ -867,8 +938,8 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); + mqtt.outgoing_publish(publish1, None).unwrap(); + mqtt.outgoing_publish(publish2, None).unwrap(); assert_eq!(mqtt.inflight, 2); mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); @@ -902,8 +973,8 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); + let _publish_out = mqtt.outgoing_publish(publish1, None); + let _publish_out = mqtt.outgoing_publish(publish2, None); mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 2); @@ -921,7 +992,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - match mqtt.outgoing_publish(publish).unwrap().unwrap() { + match mqtt.outgoing_publish(publish, None).unwrap().unwrap() { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } @@ -961,7 +1032,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); + mqtt.outgoing_publish(publish, None).unwrap(); mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap(); mqtt.handle_incoming_pubcomp(&PubComp::new(1, None)) @@ -976,7 +1047,7 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) + mqtt.handle_outgoing_packet(Request::Publish(publish), None) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None))) .unwrap(); From 6b0c3ec9d1594fcfe92faf01ec48d6bf40804b54 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 01:38:14 +0530 Subject: [PATCH 10/30] doc: add changelog --- rumqttc/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index c8c8716a3..6469227bf 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `set_session_expiry_interval` and `session_expiry_interval` methods on `MqttOptions`. * `Auth` packet as per MQTT5 standards * Allow configuring the `nodelay` property of underlying TCP client with the `tcp_nodelay` field in `NetworkOptions` +* `publish` / `subscribe` / `unsubscribe` methods on `AsyncClient` and `Client` now return an `AckPromise` which resolves when the packet(except for QoS 0 publishes, which resolve as soon as handled) is acknowledged by the broker. ### Changed From f43ea2148c2c22245f8b1d7d3c087a1250580ebc Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 16:38:15 +0530 Subject: [PATCH 11/30] fix: bug observed in https://github.com/bytebeamio/rumqtt/pull/916#issuecomment-2381874520 --- rumqttc/src/v5/state.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 5d7ebdab3..85feb959a 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -576,6 +576,10 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); + match (publish.qos, tx) { + (QoS::AtMostOnce, Some(tx)) => tx.resolve(), + (_, tx) => self.ack_waiter[publish.pkid as usize] = tx, + } Ok(Some(Packet::Publish(publish))) } From 2632ab11e986b038832445f6dc561a765b86983e Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 16:38:46 +0530 Subject: [PATCH 12/30] doc: add examples of ack notify --- rumqttc/examples/ack_promise.rs | 90 ++++++++++++++++++++++++ rumqttc/examples/ack_promise_sync.rs | 92 +++++++++++++++++++++++++ rumqttc/examples/ack_promise_v5.rs | 90 ++++++++++++++++++++++++ rumqttc/examples/ack_promise_v5_sync.rs | 92 +++++++++++++++++++++++++ 4 files changed, 364 insertions(+) create mode 100644 rumqttc/examples/ack_promise.rs create mode 100644 rumqttc/examples/ack_promise_sync.rs create mode 100644 rumqttc/examples/ack_promise_v5.rs create mode 100644 rumqttc/examples/ack_promise_v5_sync.rs diff --git a/rumqttc/examples/ack_promise.rs b/rumqttc/examples/ack_promise.rs new file mode 100644 index 000000000..5abe4b307 --- /dev/null +++ b/rumqttc/examples/ack_promise.rs @@ -0,0 +1,90 @@ +use tokio::task::{self, JoinSet}; + +use rumqttc::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .await + .unwrap(); + println!("Acknowledged Subscribe"); + + // Publish at all QoS levels and wait for broker acknowledgement + client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .await + .unwrap() + .await + .unwrap(); + println!("Acknowledged Pub(1)"); + + client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .await + .unwrap() + .await + .unwrap(); + println!("Acknowledged Pub(2)"); + + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .await + .unwrap() + .await + .unwrap(); + println!("Acknowledged Pub(3)"); + + // Publish and spawn wait for notification + let mut set = JoinSet::new(); + + let future = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .await + .unwrap(); + set.spawn(async { future.await.map(|_| 1) }); + + let future = client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .await + .unwrap(); + set.spawn(async { future.await.map(|_| 2) }); + + let future = client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .await + .unwrap(); + set.spawn(async { future.await.map(|_| 3) }); + + while let Some(res) = set.join_next().await { + println!("Acknowledged = {:?}", res?); + } + + Ok(()) +} diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs new file mode 100644 index 000000000..55b6d2be3 --- /dev/null +++ b/rumqttc/examples/ack_promise_sync.rs @@ -0,0 +1,92 @@ +use flume::bounded; +use rumqttc::{Client, MqttOptions, QoS}; +use std::error::Error; +use std::thread; +use std::time::Duration; + +fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut conn) = Client::new(mqttoptions, 10); + thread::spawn(move || { + for event in conn.iter() { + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + client + .subscribe("hello/world", QoS::AtMostOnce) + .unwrap() + .blocking_recv() + .unwrap(); + println!("Acknowledged Subscribe"); + + // Publish at all QoS levels and wait for broker acknowledgement + client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .unwrap() + .blocking_recv() + .unwrap(); + println!("Acknowledged Pub(1)"); + + client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .unwrap() + .blocking_recv() + .unwrap(); + println!("Acknowledged Pub(2)"); + + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .unwrap() + .blocking_recv() + .unwrap(); + println!("Acknowledged Pub(3)"); + + // Spawn threads for each publish, use channel to notify result + let (tx, rx) = bounded(1); + + let future = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .unwrap(); + let tx_clone = tx.clone(); + thread::spawn(move || { + let res = future.blocking_recv().map(|_| 1); + tx_clone.send(res).unwrap() + }); + + let future = client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .unwrap(); + let tx_clone = tx.clone(); + thread::spawn(move || { + let res = future.blocking_recv().map(|_| 2); + tx_clone.send(res).unwrap() + }); + + let future = client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .unwrap(); + thread::spawn(move || { + let res = future.blocking_recv().map(|_| 3); + tx.send(res).unwrap() + }); + + while let Ok(res) = rx.recv() { + println!("Acknowledged = {:?}", res?); + } + + Ok(()) +} diff --git a/rumqttc/examples/ack_promise_v5.rs b/rumqttc/examples/ack_promise_v5.rs new file mode 100644 index 000000000..c6939ad2c --- /dev/null +++ b/rumqttc/examples/ack_promise_v5.rs @@ -0,0 +1,90 @@ +use tokio::task::{self, JoinSet}; + +use rumqttc::v5::{mqttbytes::QoS, AsyncClient, MqttOptions}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .await + .unwrap(); + println!("Acknowledged Subscribe"); + + // Publish at all QoS levels and wait for broker acknowledgement + client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .await + .unwrap() + .await + .unwrap(); + println!("Acknowledged Pub(1)"); + + client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .await + .unwrap() + .await + .unwrap(); + println!("Acknowledged Pub(2)"); + + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .await + .unwrap() + .await + .unwrap(); + println!("Acknowledged Pub(3)"); + + // Publish and spawn wait for notification + let mut set = JoinSet::new(); + + let future = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .await + .unwrap(); + set.spawn(async { future.await.map(|_| 1) }); + + let future = client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .await + .unwrap(); + set.spawn(async { future.await.map(|_| 2) }); + + let future = client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .await + .unwrap(); + set.spawn(async { future.await.map(|_| 3) }); + + while let Some(res) = set.join_next().await { + println!("Acknowledged = {:?}", res?); + } + + Ok(()) +} diff --git a/rumqttc/examples/ack_promise_v5_sync.rs b/rumqttc/examples/ack_promise_v5_sync.rs new file mode 100644 index 000000000..a6b3526ea --- /dev/null +++ b/rumqttc/examples/ack_promise_v5_sync.rs @@ -0,0 +1,92 @@ +use flume::bounded; +use rumqttc::v5::{mqttbytes::QoS, Client, MqttOptions}; +use std::error::Error; +use std::thread; +use std::time::Duration; + +fn main() -> Result<(), Box> { + pretty_env_logger::init(); + // color_backtrace::install(); + + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut conn) = Client::new(mqttoptions, 10); + thread::spawn(move || { + for event in conn.iter() { + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + client + .subscribe("hello/world", QoS::AtMostOnce) + .unwrap() + .blocking_recv() + .unwrap(); + println!("Acknowledged Subscribe"); + + // Publish at all QoS levels and wait for broker acknowledgement + client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .unwrap() + .blocking_recv() + .unwrap(); + println!("Acknowledged Pub(1)"); + + client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .unwrap() + .blocking_recv() + .unwrap(); + println!("Acknowledged Pub(2)"); + + client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .unwrap() + .blocking_recv() + .unwrap(); + println!("Acknowledged Pub(3)"); + + // Spawn threads for each publish, use channel to notify result + let (tx, rx) = bounded(1); + + let future = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) + .unwrap(); + let tx_clone = tx.clone(); + thread::spawn(move || { + let res = future.blocking_recv().map(|_| 1); + tx_clone.send(res).unwrap() + }); + + let future = client + .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) + .unwrap(); + let tx_clone = tx.clone(); + thread::spawn(move || { + let res = future.blocking_recv().map(|_| 2); + tx_clone.send(res).unwrap() + }); + + let future = client + .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + .unwrap(); + thread::spawn(move || { + let res = future.blocking_recv().map(|_| 3); + tx.send(res).unwrap() + }); + + while let Ok(res) = rx.recv() { + println!("Acknowledged = {:?}", res?); + } + + Ok(()) +} From 8ad22231d17ca7e4019ec076b490084aaf8a94db Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 16:58:54 +0530 Subject: [PATCH 13/30] feat: return pkid of ack --- rumqttc/src/lib.rs | 9 +++++---- rumqttc/src/state.rs | 10 +++++----- rumqttc/src/v5/state.rs | 10 +++++----- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 30081214d..041ea68f3 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -223,11 +223,12 @@ impl From for Request { } } -pub type AckPromise = oneshot::Receiver<()>; +pub type Pkid = u16; +pub type AckPromise = oneshot::Receiver; #[derive(Debug)] pub struct PromiseTx { - inner: oneshot::Sender<()>, + inner: oneshot::Sender, } impl PromiseTx { @@ -237,8 +238,8 @@ impl PromiseTx { (PromiseTx { inner }, promise) } - fn resolve(self) { - if self.inner.send(()).is_err() { + fn resolve(self, pkid: Pkid) { + if self.inner.send(pkid).is_err() { trace!("Promise was drpped") } } diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 81a60c904..5b8c44095 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -206,7 +206,7 @@ impl MqttState { .any(|x| matches!(x, SubscribeReasonCode::Success(_))) { if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { - tx.resolve(); + tx.resolve(suback.pkid); } } @@ -223,7 +223,7 @@ impl MqttState { } if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { - tx.resolve(); + tx.resolve(unsuback.pkid); } Ok(None) @@ -271,7 +271,7 @@ impl MqttState { if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { // Resolve promise for QoS 1 - tx.resolve(); + tx.resolve(puback.pkid); } self.inflight -= 1; @@ -333,7 +333,7 @@ impl MqttState { if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { // Resolve promise for QoS 2 - tx.resolve(); + tx.resolve(pubcomp.pkid); } self.outgoing_rel.set(pubcomp.pkid as usize, false); @@ -398,7 +398,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); match (publish.qos, tx) { - (QoS::AtMostOnce, Some(tx)) => tx.resolve(), + (QoS::AtMostOnce, Some(tx)) => tx.resolve(publish.pkid), (_, tx) => self.ack_waiter[publish.pkid as usize] = tx, } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 85feb959a..ca6d0d00e 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -269,7 +269,7 @@ impl MqttState { .any(|x| matches!(x, SubscribeReasonCode::Success(_))) { if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { - tx.resolve(); + tx.resolve(suback.pkid); } } @@ -293,7 +293,7 @@ impl MqttState { if unsuback.reasons.contains(&UnsubAckReason::Success) { if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { - tx.resolve(); + tx.resolve(unsuback.pkid); } } @@ -401,7 +401,7 @@ impl MqttState { if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { // Resolve promise for QoS 1 - tx.resolve(); + tx.resolve(puback.pkid); } self.inflight -= 1; @@ -490,7 +490,7 @@ impl MqttState { if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { // Resolve promise for QoS 2 - tx.resolve(); + tx.resolve(pubcomp.pkid); } self.outgoing_rel.set(pubcomp.pkid as usize, false); @@ -577,7 +577,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); match (publish.qos, tx) { - (QoS::AtMostOnce, Some(tx)) => tx.resolve(), + (QoS::AtMostOnce, Some(tx)) => tx.resolve(0), (_, tx) => self.ack_waiter[publish.pkid as usize] = tx, } From db7c322b9d26953e53bb9a7e6a4c425f524f8f5b Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 18:39:24 +0530 Subject: [PATCH 14/30] doc: update example with pkids --- rumqttc/examples/ack_promise.rs | 26 ++++++++++++------------- rumqttc/examples/ack_promise_sync.rs | 26 ++++++++++++------------- rumqttc/examples/ack_promise_v5.rs | 26 ++++++++++++------------- rumqttc/examples/ack_promise_v5_sync.rs | 26 ++++++++++++------------- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/rumqttc/examples/ack_promise.rs b/rumqttc/examples/ack_promise.rs index 5abe4b307..2933c4adb 100644 --- a/rumqttc/examples/ack_promise.rs +++ b/rumqttc/examples/ack_promise.rs @@ -28,38 +28,38 @@ async fn main() -> Result<(), Box> { }); // Subscribe and wait for broker acknowledgement - client + let pkid = client .subscribe("hello/world", QoS::AtMostOnce) .await .unwrap() .await .unwrap(); - println!("Acknowledged Subscribe"); + println!("Acknowledged Subscribe({pkid})"); // Publish at all QoS levels and wait for broker acknowledgement - client + let pkid = client .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .await .unwrap() .await .unwrap(); - println!("Acknowledged Pub(1)"); + println!("Acknowledged Pub({pkid})"); - client + let pkid = client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .await .unwrap() .await .unwrap(); - println!("Acknowledged Pub(2)"); + println!("Acknowledged Pub({pkid})"); - client + let pkid = client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .await .unwrap() .await .unwrap(); - println!("Acknowledged Pub(3)"); + println!("Acknowledged Pub({pkid})"); // Publish and spawn wait for notification let mut set = JoinSet::new(); @@ -68,22 +68,22 @@ async fn main() -> Result<(), Box> { .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .await .unwrap(); - set.spawn(async { future.await.map(|_| 1) }); + set.spawn(async { future.await }); let future = client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .await .unwrap(); - set.spawn(async { future.await.map(|_| 2) }); + set.spawn(async { future.await }); let future = client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .await .unwrap(); - set.spawn(async { future.await.map(|_| 3) }); + set.spawn(async { future.await }); - while let Some(res) = set.join_next().await { - println!("Acknowledged = {:?}", res?); + while let Some(Ok(Ok(pkid))) = set.join_next().await { + println!("Acknowledged Pub({pkid})"); } Ok(()) diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs index 55b6d2be3..225726251 100644 --- a/rumqttc/examples/ack_promise_sync.rs +++ b/rumqttc/examples/ack_promise_sync.rs @@ -26,34 +26,34 @@ fn main() -> Result<(), Box> { }); // Subscribe and wait for broker acknowledgement - client + let pkid = client .subscribe("hello/world", QoS::AtMostOnce) .unwrap() .blocking_recv() .unwrap(); - println!("Acknowledged Subscribe"); + println!("Acknowledged Subscribe({pkid})"); // Publish at all QoS levels and wait for broker acknowledgement - client + let pkid = client .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .unwrap() .blocking_recv() .unwrap(); - println!("Acknowledged Pub(1)"); + println!("Acknowledged Pub({pkid})"); - client + let pkid = client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .unwrap() .blocking_recv() .unwrap(); - println!("Acknowledged Pub(2)"); + println!("Acknowledged Pub({pkid})"); - client + let pkid = client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap() .blocking_recv() .unwrap(); - println!("Acknowledged Pub(3)"); + println!("Acknowledged Pub({pkid})"); // Spawn threads for each publish, use channel to notify result let (tx, rx) = bounded(1); @@ -63,7 +63,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx_clone = tx.clone(); thread::spawn(move || { - let res = future.blocking_recv().map(|_| 1); + let res = future.blocking_recv(); tx_clone.send(res).unwrap() }); @@ -72,7 +72,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx_clone = tx.clone(); thread::spawn(move || { - let res = future.blocking_recv().map(|_| 2); + let res = future.blocking_recv(); tx_clone.send(res).unwrap() }); @@ -80,12 +80,12 @@ fn main() -> Result<(), Box> { .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap(); thread::spawn(move || { - let res = future.blocking_recv().map(|_| 3); + let res = future.blocking_recv(); tx.send(res).unwrap() }); - while let Ok(res) = rx.recv() { - println!("Acknowledged = {:?}", res?); + while let Ok(Ok(pkid)) = rx.recv() { + println!("Acknowledged Pub({:?})", pkid); } Ok(()) diff --git a/rumqttc/examples/ack_promise_v5.rs b/rumqttc/examples/ack_promise_v5.rs index c6939ad2c..df351cda5 100644 --- a/rumqttc/examples/ack_promise_v5.rs +++ b/rumqttc/examples/ack_promise_v5.rs @@ -28,38 +28,38 @@ async fn main() -> Result<(), Box> { }); // Subscribe and wait for broker acknowledgement - client + let pkid = client .subscribe("hello/world", QoS::AtMostOnce) .await .unwrap() .await .unwrap(); - println!("Acknowledged Subscribe"); + println!("Acknowledged Subscribe({pkid})"); // Publish at all QoS levels and wait for broker acknowledgement - client + let pkid = client .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .await .unwrap() .await .unwrap(); - println!("Acknowledged Pub(1)"); + println!("Acknowledged Pub({pkid})"); - client + let pkid = client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .await .unwrap() .await .unwrap(); - println!("Acknowledged Pub(2)"); + println!("Acknowledged Pub({pkid})"); - client + let pkid = client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .await .unwrap() .await .unwrap(); - println!("Acknowledged Pub(3)"); + println!("Acknowledged Pub({pkid})"); // Publish and spawn wait for notification let mut set = JoinSet::new(); @@ -68,22 +68,22 @@ async fn main() -> Result<(), Box> { .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .await .unwrap(); - set.spawn(async { future.await.map(|_| 1) }); + set.spawn(async { future.await }); let future = client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .await .unwrap(); - set.spawn(async { future.await.map(|_| 2) }); + set.spawn(async { future.await }); let future = client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .await .unwrap(); - set.spawn(async { future.await.map(|_| 3) }); + set.spawn(async { future.await }); - while let Some(res) = set.join_next().await { - println!("Acknowledged = {:?}", res?); + while let Some(Ok(Ok(pkid))) = set.join_next().await { + println!("Acknowledged Pub({pkid})"); } Ok(()) diff --git a/rumqttc/examples/ack_promise_v5_sync.rs b/rumqttc/examples/ack_promise_v5_sync.rs index a6b3526ea..22bee0260 100644 --- a/rumqttc/examples/ack_promise_v5_sync.rs +++ b/rumqttc/examples/ack_promise_v5_sync.rs @@ -26,34 +26,34 @@ fn main() -> Result<(), Box> { }); // Subscribe and wait for broker acknowledgement - client + let pkid = client .subscribe("hello/world", QoS::AtMostOnce) .unwrap() .blocking_recv() .unwrap(); - println!("Acknowledged Subscribe"); + println!("Acknowledged Subscribe({pkid})"); // Publish at all QoS levels and wait for broker acknowledgement - client + let pkid = client .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .unwrap() .blocking_recv() .unwrap(); - println!("Acknowledged Pub(1)"); + println!("Acknowledged Pub({pkid})"); - client + let pkid = client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .unwrap() .blocking_recv() .unwrap(); - println!("Acknowledged Pub(2)"); + println!("Acknowledged Pub({pkid})"); - client + let pkid = client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap() .blocking_recv() .unwrap(); - println!("Acknowledged Pub(3)"); + println!("Acknowledged Pub({pkid})"); // Spawn threads for each publish, use channel to notify result let (tx, rx) = bounded(1); @@ -63,7 +63,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx_clone = tx.clone(); thread::spawn(move || { - let res = future.blocking_recv().map(|_| 1); + let res = future.blocking_recv(); tx_clone.send(res).unwrap() }); @@ -72,7 +72,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx_clone = tx.clone(); thread::spawn(move || { - let res = future.blocking_recv().map(|_| 2); + let res = future.blocking_recv(); tx_clone.send(res).unwrap() }); @@ -80,12 +80,12 @@ fn main() -> Result<(), Box> { .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap(); thread::spawn(move || { - let res = future.blocking_recv().map(|_| 3); + let res = future.blocking_recv(); tx.send(res).unwrap() }); - while let Ok(res) = rx.recv() { - println!("Acknowledged = {:?}", res?); + while let Ok(Ok(pkid)) = rx.recv() { + println!("Acknowledged Pub({:?})", pkid); } Ok(()) From 31887caec96d58a49e8485faf66d32a534f12d31 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 19:32:38 +0530 Subject: [PATCH 15/30] feat: return reason of request failure --- rumqttc/examples/ack_promise.rs | 45 ++++++++++++------ rumqttc/examples/ack_promise_sync.rs | 59 ++++++++++++++--------- rumqttc/examples/ack_promise_v5.rs | 45 ++++++++++++------ rumqttc/examples/ack_promise_v5_sync.rs | 27 +++++++---- rumqttc/src/lib.rs | 62 ++++++++++++++++++++++--- rumqttc/src/state.rs | 14 +++--- rumqttc/src/v5/state.rs | 36 ++++++++++---- 7 files changed, 207 insertions(+), 81 deletions(-) diff --git a/rumqttc/examples/ack_promise.rs b/rumqttc/examples/ack_promise.rs index 2933c4adb..7366442f8 100644 --- a/rumqttc/examples/ack_promise.rs +++ b/rumqttc/examples/ack_promise.rs @@ -28,38 +28,46 @@ async fn main() -> Result<(), Box> { }); // Subscribe and wait for broker acknowledgement - let pkid = client + match client .subscribe("hello/world", QoS::AtMostOnce) .await .unwrap() .await - .unwrap(); - println!("Acknowledged Subscribe({pkid})"); + { + Ok(pkid) => println!("Acknowledged Sub({pkid})"), + Err(e) => println!("Subscription failed: {e:?}"), + } // Publish at all QoS levels and wait for broker acknowledgement - let pkid = client + match client .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .await .unwrap() .await - .unwrap(); - println!("Acknowledged Pub({pkid})"); + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } - let pkid = client + match client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .await .unwrap() .await - .unwrap(); - println!("Acknowledged Pub({pkid})"); + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } - let pkid = client + match client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .await .unwrap() .await - .unwrap(); - println!("Acknowledged Pub({pkid})"); + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } // Publish and spawn wait for notification let mut set = JoinSet::new(); @@ -82,8 +90,17 @@ async fn main() -> Result<(), Box> { .unwrap(); set.spawn(async { future.await }); - while let Some(Ok(Ok(pkid))) = set.join_next().await { - println!("Acknowledged Pub({pkid})"); + while let Some(Ok(res)) = set.join_next().await { + match res { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Unsubscribe and wait for broker acknowledgement + match client.unsubscribe("hello/world").await.unwrap().await { + Ok(pkid) => println!("Acknowledged Unsub({pkid})"), + Err(e) => println!("Unsubscription failed: {e:?}"), } Ok(()) diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs index 225726251..2158d46ad 100644 --- a/rumqttc/examples/ack_promise_sync.rs +++ b/rumqttc/examples/ack_promise_sync.rs @@ -26,34 +26,42 @@ fn main() -> Result<(), Box> { }); // Subscribe and wait for broker acknowledgement - let pkid = client + match client .subscribe("hello/world", QoS::AtMostOnce) .unwrap() - .blocking_recv() - .unwrap(); - println!("Acknowledged Subscribe({pkid})"); + .blocking_wait() + { + Ok(pkid) => println!("Acknowledged Sub({pkid})"), + Err(e) => println!("Subscription failed: {e:?}"), + } // Publish at all QoS levels and wait for broker acknowledgement - let pkid = client + match client .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .unwrap() - .blocking_recv() - .unwrap(); - println!("Acknowledged Pub({pkid})"); + .blocking_wait() + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } - let pkid = client + match client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .unwrap() - .blocking_recv() - .unwrap(); - println!("Acknowledged Pub({pkid})"); + .blocking_wait() + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } - let pkid = client + match client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap() - .blocking_recv() - .unwrap(); - println!("Acknowledged Pub({pkid})"); + .blocking_wait() + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } // Spawn threads for each publish, use channel to notify result let (tx, rx) = bounded(1); @@ -63,7 +71,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx_clone = tx.clone(); thread::spawn(move || { - let res = future.blocking_recv(); + let res = future.blocking_wait(); tx_clone.send(res).unwrap() }); @@ -72,7 +80,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx_clone = tx.clone(); thread::spawn(move || { - let res = future.blocking_recv(); + let res = future.blocking_wait(); tx_clone.send(res).unwrap() }); @@ -80,12 +88,21 @@ fn main() -> Result<(), Box> { .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap(); thread::spawn(move || { - let res = future.blocking_recv(); + let res = future.blocking_wait(); tx.send(res).unwrap() }); - while let Ok(Ok(pkid)) = rx.recv() { - println!("Acknowledged Pub({:?})", pkid); + while let Ok(res) = rx.recv() { + match res { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Unsubscribe and wait for broker acknowledgement + match client.unsubscribe("hello/world").unwrap().blocking_wait() { + Ok(pkid) => println!("Acknowledged Unsub({pkid})"), + Err(e) => println!("Unsubscription failed: {e:?}"), } Ok(()) diff --git a/rumqttc/examples/ack_promise_v5.rs b/rumqttc/examples/ack_promise_v5.rs index df351cda5..c2eb26319 100644 --- a/rumqttc/examples/ack_promise_v5.rs +++ b/rumqttc/examples/ack_promise_v5.rs @@ -28,38 +28,46 @@ async fn main() -> Result<(), Box> { }); // Subscribe and wait for broker acknowledgement - let pkid = client + match client .subscribe("hello/world", QoS::AtMostOnce) .await .unwrap() .await - .unwrap(); - println!("Acknowledged Subscribe({pkid})"); + { + Ok(pkid) => println!("Acknowledged Sub({pkid})"), + Err(e) => println!("Subscription failed: {e:?}"), + } // Publish at all QoS levels and wait for broker acknowledgement - let pkid = client + match client .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .await .unwrap() .await - .unwrap(); - println!("Acknowledged Pub({pkid})"); + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } - let pkid = client + match client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .await .unwrap() .await - .unwrap(); - println!("Acknowledged Pub({pkid})"); + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } - let pkid = client + match client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .await .unwrap() .await - .unwrap(); - println!("Acknowledged Pub({pkid})"); + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } // Publish and spawn wait for notification let mut set = JoinSet::new(); @@ -82,8 +90,17 @@ async fn main() -> Result<(), Box> { .unwrap(); set.spawn(async { future.await }); - while let Some(Ok(Ok(pkid))) = set.join_next().await { - println!("Acknowledged Pub({pkid})"); + while let Some(Ok(res)) = set.join_next().await { + match res { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Unsubscribe and wait for broker acknowledgement + match client.unsubscribe("hello/world").await.unwrap().await { + Ok(pkid) => println!("Acknowledged Unsub({pkid})"), + Err(e) => println!("Unsubscription failed: {e:?}"), } Ok(()) diff --git a/rumqttc/examples/ack_promise_v5_sync.rs b/rumqttc/examples/ack_promise_v5_sync.rs index 22bee0260..26ae240dc 100644 --- a/rumqttc/examples/ack_promise_v5_sync.rs +++ b/rumqttc/examples/ack_promise_v5_sync.rs @@ -29,7 +29,7 @@ fn main() -> Result<(), Box> { let pkid = client .subscribe("hello/world", QoS::AtMostOnce) .unwrap() - .blocking_recv() + .blocking_wait() .unwrap(); println!("Acknowledged Subscribe({pkid})"); @@ -37,21 +37,21 @@ fn main() -> Result<(), Box> { let pkid = client .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) .unwrap() - .blocking_recv() + .blocking_wait() .unwrap(); println!("Acknowledged Pub({pkid})"); let pkid = client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .unwrap() - .blocking_recv() + .blocking_wait() .unwrap(); println!("Acknowledged Pub({pkid})"); let pkid = client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap() - .blocking_recv() + .blocking_wait() .unwrap(); println!("Acknowledged Pub({pkid})"); @@ -63,7 +63,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx_clone = tx.clone(); thread::spawn(move || { - let res = future.blocking_recv(); + let res = future.blocking_wait(); tx_clone.send(res).unwrap() }); @@ -72,7 +72,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx_clone = tx.clone(); thread::spawn(move || { - let res = future.blocking_recv(); + let res = future.blocking_wait(); tx_clone.send(res).unwrap() }); @@ -80,12 +80,21 @@ fn main() -> Result<(), Box> { .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap(); thread::spawn(move || { - let res = future.blocking_recv(); + let res = future.blocking_wait(); tx.send(res).unwrap() }); - while let Ok(Ok(pkid)) = rx.recv() { - println!("Acknowledged Pub({:?})", pkid); + while let Ok(res) = rx.recv() { + match res { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Unsubscribe and wait for broker acknowledgement + match client.unsubscribe("hello/world").unwrap().blocking_wait() { + Ok(pkid) => println!("Acknowledged Unsub({pkid})"), + Err(e) => println!("Unsubscription failed: {e:?}"), } Ok(()) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 041ea68f3..cbb1a4029 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -98,7 +98,12 @@ #[macro_use] extern crate log; -use std::fmt::{self, Debug, Formatter}; +use std::{ + fmt::{self, Debug, Formatter}, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; #[cfg(any(feature = "use-rustls", feature = "websocket"))] use std::sync::Arc; @@ -224,23 +229,66 @@ impl From for Request { } pub type Pkid = u16; -pub type AckPromise = oneshot::Receiver; + +#[derive(Debug, thiserror::Error)] +pub enum PromiseError { + #[error("Sender side of channel was dropped")] + Disconnected, + #[error("Broker rejected the request, reason: {reason}")] + Rejected { reason: String }, +} + +pub struct AckPromise { + rx: oneshot::Receiver>, +} + +impl Future for AckPromise { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let polled = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }.poll(cx); + + match polled { + Poll::Ready(Ok(p)) => Poll::Ready(p), + Poll::Ready(Err(_)) => Poll::Ready(Err(PromiseError::Disconnected)), + Poll::Pending => Poll::Pending, + } + } +} + +impl AckPromise { + pub fn blocking_wait(self) -> Result { + self.rx + .blocking_recv() + .map_err(|_| PromiseError::Disconnected)? + } +} #[derive(Debug)] pub struct PromiseTx { - inner: oneshot::Sender, + tx: oneshot::Sender>, } impl PromiseTx { fn new() -> (PromiseTx, AckPromise) { - let (inner, promise) = oneshot::channel(); + let (tx, rx) = oneshot::channel(); - (PromiseTx { inner }, promise) + (PromiseTx { tx }, AckPromise { rx }) } fn resolve(self, pkid: Pkid) { - if self.inner.send(pkid).is_err() { - trace!("Promise was drpped") + if self.tx.send(Ok(pkid)).is_err() { + trace!("Promise was dropped") + } + } + + fn fail(self, reason: String) { + if self + .tx + .send(Err(PromiseError::Rejected { reason })) + .is_err() + { + trace!("Promise was dropped") } } } diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 5b8c44095..36755afe7 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -200,13 +200,15 @@ impl MqttState { return Err(StateError::Unsolicited(suback.pkid)); } - if suback - .return_codes - .iter() - .any(|x| matches!(x, SubscribeReasonCode::Success(_))) - { - if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { + if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { + if suback + .return_codes + .iter() + .all(|r| matches!(r, SubscribeReasonCode::Success(_))) + { tx.resolve(suback.pkid); + } else { + tx.fail(format!("{:?}", suback.return_codes)); } } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index ca6d0d00e..68b9016e8 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -263,13 +263,15 @@ impl MqttState { } } - if suback - .return_codes - .iter() - .any(|x| matches!(x, SubscribeReasonCode::Success(_))) - { - if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { + if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { + if suback + .return_codes + .iter() + .all(|r| matches!(r, SubscribeReasonCode::Success(_))) + { tx.resolve(suback.pkid); + } else { + tx.fail(format!("{:?}", suback.return_codes)); } } @@ -291,9 +293,15 @@ impl MqttState { } } - if unsuback.reasons.contains(&UnsubAckReason::Success) { - if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { + if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { + if unsuback + .reasons + .iter() + .all(|r| matches!(r, UnsubAckReason::Success)) + { tx.resolve(unsuback.pkid); + } else { + tx.fail(format!("{:?}", unsuback.reasons)); } } @@ -401,7 +409,11 @@ impl MqttState { if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { // Resolve promise for QoS 1 - tx.resolve(puback.pkid); + if puback.reason == PubAckReason::Success { + tx.resolve(puback.pkid); + } else { + tx.fail(format!("{:?}", puback.reason)); + } } self.inflight -= 1; @@ -490,7 +502,11 @@ impl MqttState { if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { // Resolve promise for QoS 2 - tx.resolve(pubcomp.pkid); + if pubcomp.reason == PubCompReason::Success { + tx.resolve(pubcomp.pkid); + } else { + tx.fail(format!("{:?}", pubcomp.reason)); + } } self.outgoing_rel.set(pubcomp.pkid as usize, false); From 066783af548179c6c8d8cfc73db54d86776fe879 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 30 Sep 2024 23:13:19 +0530 Subject: [PATCH 16/30] rm unnecessary example --- rumqttc/examples/ack_promise_v5_sync.rs | 101 ------------------------ 1 file changed, 101 deletions(-) delete mode 100644 rumqttc/examples/ack_promise_v5_sync.rs diff --git a/rumqttc/examples/ack_promise_v5_sync.rs b/rumqttc/examples/ack_promise_v5_sync.rs deleted file mode 100644 index 26ae240dc..000000000 --- a/rumqttc/examples/ack_promise_v5_sync.rs +++ /dev/null @@ -1,101 +0,0 @@ -use flume::bounded; -use rumqttc::v5::{mqttbytes::QoS, Client, MqttOptions}; -use std::error::Error; -use std::thread; -use std::time::Duration; - -fn main() -> Result<(), Box> { - pretty_env_logger::init(); - // color_backtrace::install(); - - let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); - mqttoptions.set_keep_alive(Duration::from_secs(5)); - - let (client, mut conn) = Client::new(mqttoptions, 10); - thread::spawn(move || { - for event in conn.iter() { - match &event { - Ok(v) => { - println!("Event = {v:?}"); - } - Err(e) => { - println!("Error = {e:?}"); - } - } - } - }); - - // Subscribe and wait for broker acknowledgement - let pkid = client - .subscribe("hello/world", QoS::AtMostOnce) - .unwrap() - .blocking_wait() - .unwrap(); - println!("Acknowledged Subscribe({pkid})"); - - // Publish at all QoS levels and wait for broker acknowledgement - let pkid = client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) - .unwrap() - .blocking_wait() - .unwrap(); - println!("Acknowledged Pub({pkid})"); - - let pkid = client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) - .unwrap() - .blocking_wait() - .unwrap(); - println!("Acknowledged Pub({pkid})"); - - let pkid = client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) - .unwrap() - .blocking_wait() - .unwrap(); - println!("Acknowledged Pub({pkid})"); - - // Spawn threads for each publish, use channel to notify result - let (tx, rx) = bounded(1); - - let future = client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) - .unwrap(); - let tx_clone = tx.clone(); - thread::spawn(move || { - let res = future.blocking_wait(); - tx_clone.send(res).unwrap() - }); - - let future = client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) - .unwrap(); - let tx_clone = tx.clone(); - thread::spawn(move || { - let res = future.blocking_wait(); - tx_clone.send(res).unwrap() - }); - - let future = client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) - .unwrap(); - thread::spawn(move || { - let res = future.blocking_wait(); - tx.send(res).unwrap() - }); - - while let Ok(res) = rx.recv() { - match res { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), - } - } - - // Unsubscribe and wait for broker acknowledgement - match client.unsubscribe("hello/world").unwrap().blocking_wait() { - Ok(pkid) => println!("Acknowledged Unsub({pkid})"), - Err(e) => println!("Unsubscription failed: {e:?}"), - } - - Ok(()) -} From 10d843c98a8f590637e0026c27a4c55352d9e724 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 1 Oct 2024 22:28:22 +0530 Subject: [PATCH 17/30] test: reliability of ack promises --- rumqttc/tests/broker.rs | 12 +++ rumqttc/tests/reliability.rs | 197 ++++++++++++++++++++++++++++++++++- 2 files changed, 208 insertions(+), 1 deletion(-) diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index 760a2ab37..c450b68b4 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -122,6 +122,18 @@ impl Broker { self.framed.write(packet).await.unwrap(); } + /// Sends a publish record + pub async fn pubrec(&mut self, pkid: u16) { + let packet = Packet::PubRec(PubRec::new(pkid)); + self.framed.write(packet).await.unwrap(); + } + + /// Sends a publish complete + pub async fn pubcomp(&mut self, pkid: u16) { + let packet = Packet::PubComp(PubComp::new(pkid)); + self.framed.write(packet).await.unwrap(); + } + /// Sends an acknowledgement pub async fn pingresp(&mut self) { let packet = Packet::PingResp; diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 633ca4706..9933dfc14 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -1,6 +1,9 @@ use matches::assert_matches; use std::time::{Duration, Instant}; -use tokio::{task, time}; +use tokio::{ + task, + time::{self, timeout}, +}; mod broker; @@ -585,3 +588,195 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly }); handle.await.unwrap(); } + +#[tokio::test] +async fn resolve_on_qos0_before_write_to_tcp_buffer() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3004); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3004, 0, false).await; + + let token = client + .publish("hello/world", QoS::AtMostOnce, false, [1; 1]) + .await + .unwrap(); + + // Token can resolve as soon as it was processed by eventloop + assert_eq!( + timeout(Duration::from_secs(1), token) + .await + .unwrap() + .unwrap(), + 0 + ); + + // Verify the packet still reached broker + // NOTE: this can't always be guaranteed + let Packet::Publish(Publish { + qos, + topic, + pkid, + payload, + .. + }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!(topic, "hello/world"); + assert_eq!(qos, QoS::AtMostOnce); + assert_eq!(payload.to_vec(), [1; 1]); + assert_eq!(pkid, 0); +} + +#[tokio::test] +async fn resolve_on_qos1_ack_from_broker() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3004); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3004, 0, false).await; + + let mut token = client + .publish("hello/world", QoS::AtLeastOnce, false, [1; 1]) + .await + .unwrap(); + + // Token shouldn't resolve before reaching broker + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + let Packet::Publish(Publish { + qos, + topic, + pkid, + payload, + .. + }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!(topic, "hello/world"); + assert_eq!(qos, QoS::AtLeastOnce); + assert_eq!(payload.to_vec(), [1; 1]); + assert_eq!(pkid, 1); + + // Token shouldn't resolve until packet is acked + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Finally ack the packet + broker.ack(1).await; + + // Token shouldn't resolve until packet is acked + assert_eq!( + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap() + .unwrap(), + 1 + ); +} + +#[tokio::test] +async fn resolve_on_qos2_ack_from_broker() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3004); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3004, 0, false).await; + + let mut token = client + .publish("hello/world", QoS::ExactlyOnce, false, [1; 1]) + .await + .unwrap(); + + // Token shouldn't resolve before reaching broker + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + let Packet::Publish(Publish { + qos, + topic, + pkid, + payload, + .. + }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!(topic, "hello/world"); + assert_eq!(qos, QoS::ExactlyOnce); + assert_eq!(payload.to_vec(), [1; 1]); + assert_eq!(pkid, 1); + + // Token shouldn't resolve till publish recorded + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Record the publish message + broker.pubrec(1).await; + + // Token shouldn't resolve till publish complete + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Complete the publish message ack + broker.pubcomp(1).await; + + // Finally the publish is QoS2 acked + assert_eq!( + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap() + .unwrap(), + 1 + ); +} From 15ffccf291c8902c03e523dc9bfbacb1b32b8daf Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 1 Oct 2024 22:40:58 +0530 Subject: [PATCH 18/30] fix: rm dup import --- rumqttc/src/lib.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index cbb1a4029..05b07a851 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -124,10 +124,7 @@ mod tls; mod websockets; #[cfg(feature = "websocket")] -use std::{ - future::{Future, IntoFuture}, - pin::Pin, -}; +use std::future::IntoFuture; #[cfg(feature = "websocket")] type RequestModifierFn = Arc< From b664fd6d2fba90c3a9159df312e2b2ff6e1966c7 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 1 Oct 2024 23:45:05 +0530 Subject: [PATCH 19/30] test: run on unique port --- rumqttc/tests/reliability.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 9933dfc14..2f559802b 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -591,7 +591,7 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly #[tokio::test] async fn resolve_on_qos0_before_write_to_tcp_buffer() { - let options = MqttOptions::new("dummy", "127.0.0.1", 3004); + let options = MqttOptions::new("dummy", "127.0.0.1", 3005); let (client, mut eventloop) = AsyncClient::new(options, 5); task::spawn(async move { @@ -609,7 +609,7 @@ async fn resolve_on_qos0_before_write_to_tcp_buffer() { } }); - let mut broker = Broker::new(3004, 0, false).await; + let mut broker = Broker::new(3005, 0, false).await; let token = client .publish("hello/world", QoS::AtMostOnce, false, [1; 1]) @@ -645,7 +645,7 @@ async fn resolve_on_qos0_before_write_to_tcp_buffer() { #[tokio::test] async fn resolve_on_qos1_ack_from_broker() { - let options = MqttOptions::new("dummy", "127.0.0.1", 3004); + let options = MqttOptions::new("dummy", "127.0.0.1", 3006); let (client, mut eventloop) = AsyncClient::new(options, 5); task::spawn(async move { @@ -663,7 +663,7 @@ async fn resolve_on_qos1_ack_from_broker() { } }); - let mut broker = Broker::new(3004, 0, false).await; + let mut broker = Broker::new(3006, 0, false).await; let mut token = client .publish("hello/world", QoS::AtLeastOnce, false, [1; 1]) @@ -710,7 +710,7 @@ async fn resolve_on_qos1_ack_from_broker() { #[tokio::test] async fn resolve_on_qos2_ack_from_broker() { - let options = MqttOptions::new("dummy", "127.0.0.1", 3004); + let options = MqttOptions::new("dummy", "127.0.0.1", 3007); let (client, mut eventloop) = AsyncClient::new(options, 5); task::spawn(async move { @@ -728,7 +728,7 @@ async fn resolve_on_qos2_ack_from_broker() { } }); - let mut broker = Broker::new(3004, 0, false).await; + let mut broker = Broker::new(3007, 0, false).await; let mut token = client .publish("hello/world", QoS::ExactlyOnce, false, [1; 1]) From b6d447d983e5f089526bb05878b78e60c40c155b Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 1 Oct 2024 23:59:50 +0530 Subject: [PATCH 20/30] doc: code comments --- rumqttc/src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 05b07a851..b18e1be82 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -225,6 +225,7 @@ impl From for Request { } } +/// Packet Identifier with which Publish/Subscribe/Unsubscribe packets are identified while inflight. pub type Pkid = u16; #[derive(Debug, thiserror::Error)] @@ -235,6 +236,9 @@ pub enum PromiseError { Rejected { reason: String }, } +/// Resolves with [`Pkid`] used against packet when: +/// 1. Packet is acknowldged by the broker, e.g. QoS 1/2 Publish, Subscribe and Unsubscribe +/// 2. QoS 0 packet finishes processing in the [`EventLoop`] pub struct AckPromise { rx: oneshot::Receiver>, } From 23e4d9a270f8030eeea5624eb68dd3f3e0fab756 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 2 Oct 2024 00:37:25 +0530 Subject: [PATCH 21/30] fix: don't expose waiters outside crate --- rumqttc/src/state.rs | 2 +- rumqttc/src/v5/state.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 36755afe7..37ddc462a 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -73,7 +73,7 @@ pub struct MqttState { /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, /// Waiters for publish/subscribe/unsubscribe acknowledgements - pub ack_waiter: Vec>, + ack_waiter: Vec>, } impl MqttState { diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 68b9016e8..97a6def30 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -113,7 +113,7 @@ pub struct MqttState { /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, /// Waiters for publish/subscribe/unsubscribe acknowledgements - pub ack_waiter: Vec>, + ack_waiter: Vec>, } impl MqttState { From 979cb8fcbc7102c2a8923c4d10a6957ad714a458 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 2 Oct 2024 12:15:57 +0530 Subject: [PATCH 22/30] feat: non-blocking `try_resolve` --- rumqttc/examples/ack_promise_sync.rs | 22 +++++++++++++++------- rumqttc/src/lib.rs | 23 +++++++++++++++++++---- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs index 2158d46ad..128be07d7 100644 --- a/rumqttc/examples/ack_promise_sync.rs +++ b/rumqttc/examples/ack_promise_sync.rs @@ -1,7 +1,7 @@ use flume::bounded; -use rumqttc::{Client, MqttOptions, QoS}; +use rumqttc::{Client, MqttOptions, PromiseError, QoS}; use std::error::Error; -use std::thread; +use std::thread::{self, sleep}; use std::time::Duration; fn main() -> Result<(), Box> { @@ -48,7 +48,7 @@ fn main() -> Result<(), Box> { match client .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) .unwrap() - .blocking_wait() + .try_resolve() { Ok(pkid) => println!("Acknowledged Pub({pkid})"), Err(e) => println!("Publish failed: {e:?}"), @@ -84,12 +84,20 @@ fn main() -> Result<(), Box> { tx_clone.send(res).unwrap() }); - let future = client + let mut future = client .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) .unwrap(); - thread::spawn(move || { - let res = future.blocking_wait(); - tx.send(res).unwrap() + thread::spawn(move || loop { + match future.try_resolve() { + Err(PromiseError::Waiting) => { + println!("Promise yet to resolve, retrying"); + sleep(Duration::from_secs(1)); + } + res => { + tx.send(res).unwrap(); + break; + } + } }); while let Ok(res) = rx.recv() { diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index b18e1be82..75af36459 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -136,9 +136,7 @@ type RequestModifierFn = Arc< #[cfg(feature = "proxy")] mod proxy; -pub use client::{ - AsyncClient, Client, ClientError, Connection, Iter, RecvError, RecvTimeoutError, TryRecvError, -}; +pub use client::{AsyncClient, Client, ClientError, Connection, Iter, RecvError, RecvTimeoutError}; pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; @@ -147,7 +145,7 @@ use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] pub use tls::Error as TlsError; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, oneshot::error::TryRecvError}; #[cfg(feature = "use-native-tls")] pub use tokio_native_tls; #[cfg(feature = "use-native-tls")] @@ -230,6 +228,8 @@ pub type Pkid = u16; #[derive(Debug, thiserror::Error)] pub enum PromiseError { + #[error("Sender has nothing to send instantly")] + Waiting, #[error("Sender side of channel was dropped")] Disconnected, #[error("Broker rejected the request, reason: {reason}")] @@ -263,6 +263,21 @@ impl AckPromise { .blocking_recv() .map_err(|_| PromiseError::Disconnected)? } + + /// Attempts to check if the broker acknowledged the packet, without blocking the current thread. + /// + /// Returns [`PromiseError::Waiting`] if the packet wasn't acknowledged yet. + /// + /// Multiple calls to this functions can fail with [`PromiseError::Disconnected`] if the promise + /// has already been resolved. + pub fn try_resolve(&mut self) -> Result { + match self.rx.try_recv() { + Ok(Ok(p)) => Ok(p), + Ok(Err(e)) => Err(e), + Err(TryRecvError::Empty) => Err(PromiseError::Waiting), + Err(TryRecvError::Closed) => Err(PromiseError::Disconnected), + } + } } #[derive(Debug)] From 79d5ff83fc14c669917b50e148ce482c47b0cf56 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 2 Oct 2024 12:19:18 +0530 Subject: [PATCH 23/30] doc: comment on `blocking_wait` --- rumqttc/src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 75af36459..2a6ead185 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -258,6 +258,10 @@ impl Future for AckPromise { } impl AckPromise { + /// Blocks on the current thread and waits till the packet is acknowledged by the broker. + /// + /// Returns [`PromiseError::Disconnected`] if the [`EventLoop`] was dropped(usually), + /// [`PromiseError::Rejected`] if the packet acknowledged but not accepted. pub fn blocking_wait(self) -> Result { self.rx .blocking_recv() From 165d1695d12aca95af76914ca7e9ec222978c448 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 6 Oct 2024 22:36:12 +0530 Subject: [PATCH 24/30] refactor: condense and simplify examples --- rumqttc/examples/ack_promise.rs | 70 ++++++++------------------ rumqttc/examples/ack_promise_sync.rs | 75 +++++++++++----------------- rumqttc/examples/ack_promise_v5.rs | 70 ++++++++------------------ 3 files changed, 72 insertions(+), 143 deletions(-) diff --git a/rumqttc/examples/ack_promise.rs b/rumqttc/examples/ack_promise.rs index 7366442f8..55ff7493e 100644 --- a/rumqttc/examples/ack_promise.rs +++ b/rumqttc/examples/ack_promise.rs @@ -6,9 +6,6 @@ use std::time::Duration; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { - pretty_env_logger::init(); - // color_backtrace::install(); - let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); mqttoptions.set_keep_alive(Duration::from_secs(5)); @@ -39,57 +36,34 @@ async fn main() -> Result<(), Box> { } // Publish at all QoS levels and wait for broker acknowledgement - match client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) - .await - .unwrap() - .await - { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), - } - - match client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) - .await - .unwrap() - .await + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), + match client + .publish("hello/world", qos, false, vec![1; i]) + .await + .unwrap() + .await + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } } - match client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) - .await - .unwrap() - .await + // Publish with different QoS levels and spawn wait for notification + let mut set = JoinSet::new(); + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), + let token = client + .publish("hello/world", qos, false, vec![1; i]) + .await + .unwrap(); + set.spawn(token); } - // Publish and spawn wait for notification - let mut set = JoinSet::new(); - - let future = client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) - .await - .unwrap(); - set.spawn(async { future.await }); - - let future = client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) - .await - .unwrap(); - set.spawn(async { future.await }); - - let future = client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) - .await - .unwrap(); - set.spawn(async { future.await }); - while let Some(Ok(res)) = set.join_next().await { match res { Ok(pkid) => println!("Acknowledged Pub({pkid})"), diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs index 128be07d7..6d13a9e39 100644 --- a/rumqttc/examples/ack_promise_sync.rs +++ b/rumqttc/examples/ack_promise_sync.rs @@ -5,9 +5,6 @@ use std::thread::{self, sleep}; use std::time::Duration; fn main() -> Result<(), Box> { - pretty_env_logger::init(); - // color_backtrace::install(); - let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); mqttoptions.set_keep_alive(Duration::from_secs(5)); @@ -36,59 +33,43 @@ fn main() -> Result<(), Box> { } // Publish at all QoS levels and wait for broker acknowledgement - match client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) - .unwrap() - .blocking_wait() - { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), - } - - match client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) - .unwrap() - .try_resolve() + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), - } - - match client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) - .unwrap() - .blocking_wait() - { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), + match client + .publish("hello/world", qos, false, vec![1; i]) + .unwrap() + .blocking_wait() + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } } // Spawn threads for each publish, use channel to notify result let (tx, rx) = bounded(1); - let future = client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) - .unwrap(); - let tx_clone = tx.clone(); - thread::spawn(move || { - let res = future.blocking_wait(); - tx_clone.send(res).unwrap() - }); - - let future = client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) - .unwrap(); - let tx_clone = tx.clone(); - thread::spawn(move || { - let res = future.blocking_wait(); - tx_clone.send(res).unwrap() - }); + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() + { + let token = client + .publish("hello/world", qos, false, vec![1; i]) + .unwrap(); + let tx = tx.clone(); + thread::spawn(move || { + let res = token.blocking_wait(); + tx.send(res).unwrap() + }); + } - let mut future = client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) + // Try resolving a promise, if it is waiting to resolve, try again after a sleep of 1s + let mut token = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 4]) .unwrap(); thread::spawn(move || loop { - match future.try_resolve() { + match token.try_resolve() { Err(PromiseError::Waiting) => { println!("Promise yet to resolve, retrying"); sleep(Duration::from_secs(1)); diff --git a/rumqttc/examples/ack_promise_v5.rs b/rumqttc/examples/ack_promise_v5.rs index c2eb26319..8873cf6af 100644 --- a/rumqttc/examples/ack_promise_v5.rs +++ b/rumqttc/examples/ack_promise_v5.rs @@ -6,9 +6,6 @@ use std::time::Duration; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { - pretty_env_logger::init(); - // color_backtrace::install(); - let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); mqttoptions.set_keep_alive(Duration::from_secs(5)); @@ -39,57 +36,34 @@ async fn main() -> Result<(), Box> { } // Publish at all QoS levels and wait for broker acknowledgement - match client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) - .await - .unwrap() - .await - { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), - } - - match client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) - .await - .unwrap() - .await + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), + match client + .publish("hello/world", qos, false, vec![1; i]) + .await + .unwrap() + .await + { + Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Err(e) => println!("Publish failed: {e:?}"), + } } - match client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) - .await - .unwrap() - .await + // Publish with different QoS levels and spawn wait for notification + let mut set = JoinSet::new(); + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), - Err(e) => println!("Publish failed: {e:?}"), + let token = client + .publish("hello/world", qos, false, vec![1; i]) + .await + .unwrap(); + set.spawn(token); } - // Publish and spawn wait for notification - let mut set = JoinSet::new(); - - let future = client - .publish("hello/world", QoS::AtMostOnce, false, vec![1; 1]) - .await - .unwrap(); - set.spawn(async { future.await }); - - let future = client - .publish("hello/world", QoS::AtLeastOnce, false, vec![1; 2]) - .await - .unwrap(); - set.spawn(async { future.await }); - - let future = client - .publish("hello/world", QoS::ExactlyOnce, false, vec![1; 3]) - .await - .unwrap(); - set.spawn(async { future.await }); - while let Some(Ok(res)) = set.join_next().await { match res { Ok(pkid) => println!("Acknowledged Pub({pkid})"), From bfeb44d3b40cd5fa3386e69a7786cd60e3980cd2 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 8 Oct 2024 13:41:18 +0530 Subject: [PATCH 25/30] test: working of sub/unsub promises --- rumqttc/tests/broker.rs | 52 +++++++++++-- rumqttc/tests/reliability.rs | 147 +++++++++++++++++++++++++++++++---- 2 files changed, 179 insertions(+), 20 deletions(-) diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index c450b68b4..609b381ed 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -9,9 +9,39 @@ use tokio::{task, time}; use bytes::BytesMut; use flume::{bounded, Receiver, Sender}; -use rumqttc::{Event, Incoming, Outgoing, Packet}; +use rumqttc::{Incoming, Packet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +#[derive(Debug, PartialEq)] +pub enum Event { + Incoming(Packet), + Outgoing(Outgoing), +} + +#[derive(Debug, PartialEq)] +pub enum Outgoing { + /// Publish packet with packet identifier. 0 implies QoS 0 + Publish(u16), + /// SubAck packet with packet identifier + SubAck(u16), + /// UnsubAck packet with packet identifier + UnsubAck(u16), + /// PubAck packet + PubAck(u16), + /// PubRec packet + PubRec(u16), + /// PubRel packet + PubRel(u16), + /// PubComp packet + PubComp(u16), + /// Ping request packet + PingReq, + /// Ping response packet + PingResp, + /// Disconnect packet + Disconnect, +} + pub struct Broker { pub(crate) framed: Network, pub(crate) incoming: VecDeque, @@ -116,8 +146,8 @@ impl Broker { } } - /// Sends an acknowledgement - pub async fn ack(&mut self, pkid: u16) { + /// Sends a publish acknowledgement + pub async fn puback(&mut self, pkid: u16) { let packet = Packet::PubAck(PubAck::new(pkid)); self.framed.write(packet).await.unwrap(); } @@ -134,6 +164,18 @@ impl Broker { self.framed.write(packet).await.unwrap(); } + /// Sends a subscribe acknowledgement + pub async fn suback(&mut self, pkid: u16, qos: QoS) { + let packet = Packet::SubAck(SubAck::new(pkid, vec![SubscribeReasonCode::Success(qos)])); + self.framed.write(packet).await.unwrap(); + } + + /// Sends an unsubscribe acknowledgement + pub async fn unsuback(&mut self, pkid: u16) { + let packet = Packet::UnsubAck(UnsubAck::new(pkid)); + self.framed.write(packet).await.unwrap(); + } + /// Sends an acknowledgement pub async fn pingresp(&mut self) { let packet = Packet::PingResp; @@ -308,8 +350,8 @@ fn outgoing(packet: &Packet) -> Outgoing { Packet::PubRec(pubrec) => Outgoing::PubRec(pubrec.pkid), Packet::PubRel(pubrel) => Outgoing::PubRel(pubrel.pkid), Packet::PubComp(pubcomp) => Outgoing::PubComp(pubcomp.pkid), - Packet::Subscribe(subscribe) => Outgoing::Subscribe(subscribe.pkid), - Packet::Unsubscribe(unsubscribe) => Outgoing::Unsubscribe(unsubscribe.pkid), + Packet::SubAck(suback) => Outgoing::SubAck(suback.pkid), + Packet::UnsubAck(unsuback) => Outgoing::UnsubAck(unsuback.pkid), Packet::PingReq => Outgoing::PingReq, Packet::PingResp => Outgoing::PingResp, Packet::Disconnect => Outgoing::Disconnect, diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 2f559802b..2915d3e8e 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -179,7 +179,7 @@ async fn some_outgoing_and_no_incoming_should_trigger_pings_on_time() { loop { let event = broker.tick().await; - if event == Event::Incoming(Incoming::PingReq) { + if event == broker::Event::Incoming(Incoming::PingReq) { // wait for 3 pings count += 1; if count == 3 { @@ -218,7 +218,7 @@ async fn some_incoming_and_no_outgoing_should_trigger_pings_on_time() { loop { let event = broker.tick().await; - if event == Event::Incoming(Incoming::PingReq) { + if event == broker::Event::Incoming(Incoming::PingReq) { // wait for 3 pings count += 1; if count == 3 { @@ -320,12 +320,12 @@ async fn requests_are_recovered_after_inflight_queue_size_falls_below_max() { assert!(broker.read_publish().await.is_none()); // ack packet 1 and client would produce packet 4 - broker.ack(1).await; + broker.puback(1).await; assert!(broker.read_publish().await.is_some()); assert!(broker.read_publish().await.is_none()); // ack packet 2 and client would produce packet 5 - broker.ack(2).await; + broker.puback(2).await; assert!(broker.read_publish().await.is_some()); assert!(broker.read_publish().await.is_none()); } @@ -353,18 +353,18 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { } // out of order ack - broker.ack(3).await; - broker.ack(4).await; + broker.puback(3).await; + broker.puback(4).await; time::sleep(Duration::from_secs(5)).await; - broker.ack(1).await; - broker.ack(2).await; + broker.puback(1).await; + broker.puback(2).await; // read and ack remaining packets in order for i in 5..=15 { let packet = broker.read_publish().await; let packet = packet.unwrap(); assert_eq!(packet.payload[0], i); - broker.ack(packet.pkid).await; + broker.puback(packet.pkid).await; } time::sleep(Duration::from_secs(10)).await; @@ -376,7 +376,7 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { // Poll until there is collision. loop { match eventloop.poll().await.unwrap() { - Event::Outgoing(Outgoing::AwaitAck(1)) => break, + rumqttc::Event::Outgoing(rumqttc::Outgoing::AwaitAck(1)) => break, v => { println!("Poll = {v:?}"); continue; @@ -390,7 +390,7 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { println!("Poll = {event:?}"); match event { - Event::Outgoing(Outgoing::Publish(ack)) => { + rumqttc::Event::Outgoing(rumqttc::Outgoing::Publish(ack)) => { if ack == 1 { let elapsed = start.elapsed().as_millis() as i64; let deviation_millis: i64 = (5000 - elapsed).abs(); @@ -466,7 +466,7 @@ async fn next_poll_after_connect_failure_reconnects() { } match eventloop.poll().await { - Ok(Event::Incoming(Packet::ConnAck(ConnAck { + Ok(rumqttc::Event::Incoming(Packet::ConnAck(ConnAck { code: ConnectReturnCode::Success, session_present: false, }))) => (), @@ -498,7 +498,7 @@ async fn reconnection_resumes_from_the_previous_state() { for i in 1..=2 { let packet = broker.read_publish().await.unwrap(); assert_eq!(i, packet.payload[0]); - broker.ack(packet.pkid).await; + broker.puback(packet.pkid).await; } // NOTE: An interesting thing to notice here is that reassigning a new broker @@ -512,7 +512,7 @@ async fn reconnection_resumes_from_the_previous_state() { for i in 3..=4 { let packet = broker.read_publish().await.unwrap(); assert_eq!(i, packet.payload[0]); - broker.ack(packet.pkid).await; + broker.puback(packet.pkid).await; } } @@ -696,7 +696,7 @@ async fn resolve_on_qos1_ack_from_broker() { .unwrap_err(); // Finally ack the packet - broker.ack(1).await; + broker.puback(1).await; // Token shouldn't resolve until packet is acked assert_eq!( @@ -780,3 +780,120 @@ async fn resolve_on_qos2_ack_from_broker() { 1 ); } + +#[tokio::test] +async fn resolve_on_sub_ack_from_broker() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3006); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3006, 0, false).await; + + let mut token = client + .subscribe("hello/world", QoS::AtLeastOnce) + .await + .unwrap(); + + // Token shouldn't resolve before reaching broker + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + let Packet::Subscribe(Subscribe { pkid, filters, .. }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!( + filters, + [SubscribeFilter { + path: "hello/world".to_owned(), + qos: QoS::AtLeastOnce + }] + ); + assert_eq!(pkid, 1); + + // Token shouldn't resolve until packet is acked + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Finally ack the packet + broker.suback(1, QoS::AtLeastOnce).await; + + // Token shouldn't resolve until packet is acked + assert_eq!( + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap() + .unwrap(), + 1 + ); +} + +#[tokio::test] +async fn resolve_on_unsub_ack_from_broker() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3006); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3006, 0, false).await; + + let mut token = client.unsubscribe("hello/world").await.unwrap(); + + // Token shouldn't resolve before reaching broker + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + let Packet::Unsubscribe(Unsubscribe { topics, pkid, .. }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!(topics, vec!["hello/world"]); + assert_eq!(pkid, 1); + + // Token shouldn't resolve until packet is acked + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Finally ack the packet + broker.unsuback(1).await; + + // Token shouldn't resolve until packet is acked + assert_eq!( + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap() + .unwrap(), + 1 + ); +} From a0e678bf10d49152bead59b3dd79bb9973c4559f Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Fri, 15 Nov 2024 23:25:00 +0530 Subject: [PATCH 26/30] feat: tokens for all requests (#921) --- rumqttc/examples/ack_promise_sync.rs | 14 +- rumqttc/src/client.rs | 278 ++++++++++---------- rumqttc/src/eventloop.rs | 24 +- rumqttc/src/lib.rs | 138 +--------- rumqttc/src/state.rs | 305 ++++++++++++---------- rumqttc/src/tokens.rs | 118 +++++++++ rumqttc/src/v5/client.rs | 329 ++++++++++++------------ rumqttc/src/v5/eventloop.rs | 23 +- rumqttc/src/v5/mod.rs | 29 +-- rumqttc/src/v5/state.rs | 362 +++++++++++++++------------ 10 files changed, 849 insertions(+), 771 deletions(-) create mode 100644 rumqttc/src/tokens.rs diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs index 6d13a9e39..506ee486c 100644 --- a/rumqttc/examples/ack_promise_sync.rs +++ b/rumqttc/examples/ack_promise_sync.rs @@ -1,5 +1,5 @@ use flume::bounded; -use rumqttc::{Client, MqttOptions, PromiseError, QoS}; +use rumqttc::{Client, MqttOptions, QoS, TokenError}; use std::error::Error; use std::thread::{self, sleep}; use std::time::Duration; @@ -26,7 +26,7 @@ fn main() -> Result<(), Box> { match client .subscribe("hello/world", QoS::AtMostOnce) .unwrap() - .blocking_wait() + .wait() { Ok(pkid) => println!("Acknowledged Sub({pkid})"), Err(e) => println!("Subscription failed: {e:?}"), @@ -40,7 +40,7 @@ fn main() -> Result<(), Box> { match client .publish("hello/world", qos, false, vec![1; i]) .unwrap() - .blocking_wait() + .wait() { Ok(pkid) => println!("Acknowledged Pub({pkid})"), Err(e) => println!("Publish failed: {e:?}"), @@ -59,7 +59,7 @@ fn main() -> Result<(), Box> { .unwrap(); let tx = tx.clone(); thread::spawn(move || { - let res = token.blocking_wait(); + let res = token.wait(); tx.send(res).unwrap() }); } @@ -69,8 +69,8 @@ fn main() -> Result<(), Box> { .publish("hello/world", QoS::AtMostOnce, false, vec![1; 4]) .unwrap(); thread::spawn(move || loop { - match token.try_resolve() { - Err(PromiseError::Waiting) => { + match token.check() { + Err(TokenError::Waiting) => { println!("Promise yet to resolve, retrying"); sleep(Duration::from_secs(1)); } @@ -89,7 +89,7 @@ fn main() -> Result<(), Box> { } // Unsubscribe and wait for broker acknowledgement - match client.unsubscribe("hello/world").unwrap().blocking_wait() { + match client.unsubscribe("hello/world").unwrap().wait() { Ok(pkid) => println!("Acknowledged Unsub({pkid})"), Err(e) => println!("Unsubscription failed: {e:?}"), } diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index 47b9da0c8..ffab94f0c 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -3,9 +3,9 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; +use crate::tokens::{NoResponse, Resolver, Token}; use crate::{ - valid_filter, valid_topic, AckPromise, ConnectionError, Event, EventLoop, MqttOptions, - PromiseTx, Request, + valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Pkid, Request, }; use bytes::Bytes; @@ -23,15 +23,15 @@ pub enum ClientError { TryRequest(Request), } -impl From)>> for ClientError { - fn from(e: SendError<(Request, Option)>) -> Self { - Self::Request(e.into_inner().0) +impl From> for ClientError { + fn from(e: SendError) -> Self { + Self::Request(e.into_inner()) } } -impl From)>> for ClientError { - fn from(e: TrySendError<(Request, Option)>) -> Self { - Self::TryRequest(e.into_inner().0) +impl From> for ClientError { + fn from(e: TrySendError) -> Self { + Self::TryRequest(e.into_inner()) } } @@ -44,7 +44,7 @@ impl From)>> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender<(Request, Option)>, + request_tx: Sender, } impl AsyncClient { @@ -64,7 +64,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender<(Request, Option)>) -> AsyncClient { + pub fn from_senders(request_tx: Sender) -> AsyncClient { AsyncClient { request_tx } } @@ -75,24 +75,22 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result, ClientError> where S: Into, V: Into>, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish, resolver); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((publish, Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -102,43 +100,44 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result, ClientError> where S: Into, V: Into>, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish, resolver); if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); + return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send((publish, Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - + pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.request_tx.send_async((ack, None)).await?; + self.request_tx.send_async(ack).await?; } - Ok(()) + Ok(token) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.request_tx.try_send((ack, None))?; + self.request_tx.try_send(ack)?; } - Ok(()) + Ok(token) } /// Sends a MQTT Publish to the `EventLoop` @@ -148,19 +147,17 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result + ) -> Result, ClientError> where S: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); - self.request_tx - .send_async((publish, Some(promise_tx))) - .await?; + let request = Request::Publish(publish, resolver); + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -168,17 +165,18 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((subscribe.into(), Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Attempts to send a MQTT Subscribe to the `EventLoop` @@ -186,94 +184,101 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((subscribe.into(), Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub async fn subscribe_many(&self, topics: T) -> Result + pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((subscribe.into(), Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((subscribe.into(), Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub async fn unsubscribe>(&self, topic: S) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + pub async fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - self.request_tx - .send_async((unsubscribe.into(), Some(promise_tx))) - .await?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - self.request_tx - .try_send((unsubscribe.into(), Some(promise_tx)))?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect(Disconnect); - self.request_tx.send_async((request, None)).await?; + pub async fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.request_tx.send_async(request).await?; - Ok(()) + Ok(token) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect(Disconnect); - self.request_tx.try_send((request, None))?; + pub fn try_disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.request_tx.try_send(request)?; - Ok(()) + Ok(token) } } -fn get_ack_req(publish: &Publish) -> Option { +fn get_ack_req(publish: &Publish, resolver: Resolver<()>) -> Option { let ack = match publish.qos { - QoS::AtMostOnce => return None, - QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)), + QoS::AtMostOnce => { + resolver.resolve(()); + return None; + } + QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid), resolver), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid), resolver), }; Some(ack) } @@ -313,7 +318,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender<(Request, Option)>) -> Client { + pub fn from_sender(request_tx: Sender) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -326,22 +331,22 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result, ClientError> where S: Into, V: Into>, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.client.request_tx.send((publish, Some(promise_tx)))?; + self.client.request_tx.send(publish)?; - Ok(promise) + Ok(token) } pub fn try_publish( @@ -350,7 +355,7 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result + ) -> Result, ClientError> where S: Into, V: Into>, @@ -359,18 +364,18 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - + pub fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.client.request_tx.send((ack, None))?; + self.client.request_tx.send(ack)?; } - Ok(()) + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { self.client.try_ack(publish) } @@ -379,17 +384,17 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.client - .request_tx - .send((subscribe.into(), Some(promise_tx)))?; + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -397,28 +402,29 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn subscribe_many(&self, topics: T) -> Result + pub fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.client - .request_tx - .send((subscribe.into(), Some(promise_tx)))?; + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -426,31 +432,31 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn unsubscribe>(&self, topic: S) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send((request, Some(promise_tx)))?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result { - let (promise_tx, promise) = PromiseTx::new(); - let request = Request::Disconnect(Disconnect); - self.client.request_tx.send((request, Some(promise_tx)))?; + pub fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result, ClientError> { self.client.try_disconnect() } } diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index b98a390b2..00063bd0f 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -1,5 +1,5 @@ use crate::{framed::Network, Transport}; -use crate::{Incoming, MqttState, NetworkOptions, Packet, PromiseTx, Request, StateError}; +use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError}; use crate::{MqttOptions, Outgoing}; use crate::framed::AsyncReadWrite; @@ -75,11 +75,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver<(Request, Option)>, + requests_rx: Receiver, /// Requests handle to send requests - pub(crate) requests_tx: Sender<(Request, Option)>, + pub(crate) requests_tx: Sender, /// Pending packets from last session - pub pending: VecDeque<(Request, Option)>, + pub pending: VecDeque, /// Network connection to the broker pub network: Option, /// Keep alive time @@ -132,9 +132,9 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|(request, _)| { + requests_in_channel.retain(|request| { match request { - Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack + Request::PubAck(..) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, } }); @@ -241,8 +241,8 @@ impl EventLoop { &self.requests_rx, self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok((request, tx)) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request, tx)? { + Ok(request) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -260,7 +260,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq), None)? { + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { @@ -282,10 +282,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque<(Request, Option)>, - rx: &Receiver<(Request, Option)>, + pending: &mut VecDeque, + rx: &Receiver, pending_throttle: Duration, - ) -> Result<(Request, Option), ConnectionError> { + ) -> Result { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .pop_front() AFTER sleep() otherwise we would have diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 2a6ead185..b707782db 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -98,12 +98,7 @@ #[macro_use] extern crate log; -use std::{ - fmt::{self, Debug, Formatter}, - future::Future, - pin::Pin, - task::{Context, Poll}, -}; +use std::fmt::{self, Debug, Formatter}; #[cfg(any(feature = "use-rustls", feature = "websocket"))] use std::sync::Arc; @@ -135,6 +130,7 @@ type RequestModifierFn = Arc< #[cfg(feature = "proxy")] mod proxy; +mod tokens; pub use client::{AsyncClient, Client, ClientError, Connection, Iter, RecvError, RecvTimeoutError}; pub use eventloop::{ConnectionError, Event, EventLoop}; @@ -145,7 +141,8 @@ use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] pub use tls::Error as TlsError; -use tokio::sync::{oneshot, oneshot::error::TryRecvError}; +use tokens::Resolver; +pub use tokens::{Token, TokenError}; #[cfg(feature = "use-native-tls")] pub use tokio_native_tls; #[cfg(feature = "use-native-tls")] @@ -189,130 +186,21 @@ pub enum Outgoing { /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug)] pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), - PingReq(PingReq), - PingResp(PingResp), - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect(Disconnect), -} - -impl From for Request { - fn from(publish: Publish) -> Request { - Request::Publish(publish) - } -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(subscribe) - } -} - -impl From for Request { - fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(unsubscribe) - } + Publish(Publish, Resolver), + PubAck(PubAck, Resolver<()>), + PubRec(PubRec, Resolver<()>), + PubRel(PubRel, Resolver), + Subscribe(Subscribe, Resolver), + Unsubscribe(Unsubscribe, Resolver), + Disconnect(Resolver<()>), + PingReq, } /// Packet Identifier with which Publish/Subscribe/Unsubscribe packets are identified while inflight. pub type Pkid = u16; -#[derive(Debug, thiserror::Error)] -pub enum PromiseError { - #[error("Sender has nothing to send instantly")] - Waiting, - #[error("Sender side of channel was dropped")] - Disconnected, - #[error("Broker rejected the request, reason: {reason}")] - Rejected { reason: String }, -} - -/// Resolves with [`Pkid`] used against packet when: -/// 1. Packet is acknowldged by the broker, e.g. QoS 1/2 Publish, Subscribe and Unsubscribe -/// 2. QoS 0 packet finishes processing in the [`EventLoop`] -pub struct AckPromise { - rx: oneshot::Receiver>, -} - -impl Future for AckPromise { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let polled = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }.poll(cx); - - match polled { - Poll::Ready(Ok(p)) => Poll::Ready(p), - Poll::Ready(Err(_)) => Poll::Ready(Err(PromiseError::Disconnected)), - Poll::Pending => Poll::Pending, - } - } -} - -impl AckPromise { - /// Blocks on the current thread and waits till the packet is acknowledged by the broker. - /// - /// Returns [`PromiseError::Disconnected`] if the [`EventLoop`] was dropped(usually), - /// [`PromiseError::Rejected`] if the packet acknowledged but not accepted. - pub fn blocking_wait(self) -> Result { - self.rx - .blocking_recv() - .map_err(|_| PromiseError::Disconnected)? - } - - /// Attempts to check if the broker acknowledged the packet, without blocking the current thread. - /// - /// Returns [`PromiseError::Waiting`] if the packet wasn't acknowledged yet. - /// - /// Multiple calls to this functions can fail with [`PromiseError::Disconnected`] if the promise - /// has already been resolved. - pub fn try_resolve(&mut self) -> Result { - match self.rx.try_recv() { - Ok(Ok(p)) => Ok(p), - Ok(Err(e)) => Err(e), - Err(TryRecvError::Empty) => Err(PromiseError::Waiting), - Err(TryRecvError::Closed) => Err(PromiseError::Disconnected), - } - } -} - -#[derive(Debug)] -pub struct PromiseTx { - tx: oneshot::Sender>, -} - -impl PromiseTx { - fn new() -> (PromiseTx, AckPromise) { - let (tx, rx) = oneshot::channel(); - - (PromiseTx { tx }, AckPromise { rx }) - } - - fn resolve(self, pkid: Pkid) { - if self.tx.send(Ok(pkid)).is_err() { - trace!("Promise was dropped") - } - } - - fn fail(self, reason: String) { - if self - .tx - .send(Err(PromiseError::Rejected { reason })) - .is_err() - { - trace!("Promise was dropped") - } - } -} - /// Transport methods. Defaults to TCP. #[derive(Clone)] pub enum Transport { diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 37ddc462a..4698996a7 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,9 +1,10 @@ -use crate::{Event, Incoming, Outgoing, PromiseTx, Request}; +use crate::Pkid; +use crate::{tokens::Resolver, Event, Incoming, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; use fixedbitset::FixedBitSet; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; /// Errors during state handling @@ -67,13 +68,17 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option<(Publish, Option)>, + pub collision: Option<(Publish, Resolver)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - /// Waiters for publish/subscribe/unsubscribe acknowledgements - ack_waiter: Vec>, + /// Waiters for publish acknowledgements + pub_ack_waiter: HashMap>, + /// Waiters for subscribe acknowledgements + sub_ack_waiter: HashMap>, + /// Waiters for unsubscribe acknowledgements + unsub_ack_waiter: HashMap>, } impl MqttState { @@ -98,12 +103,14 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), manual_acks, - ack_waiter: (0..max_inflight as usize + 1).map(|_| None).collect(), + pub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + sub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + unsub_ack_waiter: HashMap::with_capacity(max_inflight as usize), } } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec<(Request, Option)> { + pub fn clean(&mut self) -> Vec { let mut pending = Vec::with_capacity(100); let (first_half, second_half) = self .outgoing_pub @@ -111,17 +118,18 @@ impl MqttState { for publish in second_half.iter_mut().chain(first_half) { if let Some(publish) = publish.take() { - let tx = self.ack_waiter[publish.pkid as usize].take(); - let request = Request::Publish(publish); - pending.push((request, tx)); + let resolver = self.pub_ack_waiter.remove(&publish.pkid).unwrap(); + let request = Request::Publish(publish, resolver); + pending.push(request); } } // remove and collect pending releases for pkid in self.outgoing_rel.ones() { - let tx = self.ack_waiter[pkid].take(); - let request = Request::PubRel(PubRel::new(pkid as u16)); - pending.push((request, tx)); + let pkid = pkid as u16; + let resolver = self.pub_ack_waiter.remove(&pkid).unwrap(); + let request = Request::PubRel(PubRel::new(pkid), resolver); + pending.push(request); } self.outgoing_rel.clear(); @@ -143,18 +151,29 @@ impl MqttState { pub fn handle_outgoing_packet( &mut self, request: Request, - tx: Option, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish, tx)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, - Request::PingReq(_) => self.outgoing_ping()?, - Request::Disconnect(_) => self.outgoing_disconnect()?, - Request::PubAck(puback) => self.outgoing_puback(puback)?, - Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, - _ => unimplemented!(), + Request::Publish(publish, resolver) => self.outgoing_publish(publish, resolver)?, + Request::PubRel(pubrel, resolver) => self.outgoing_pubrel(pubrel, resolver)?, + Request::Subscribe(subscribe, resolver) => { + self.outgoing_subscribe(subscribe, resolver)? + } + Request::Unsubscribe(unsubscribe, resolver) => { + self.outgoing_unsubscribe(unsubscribe, resolver)? + } + Request::PingReq => self.outgoing_ping()?, + Request::Disconnect(resolver) => { + resolver.resolve(()); + self.outgoing_disconnect()? + } + Request::PubAck(puback, resolver) => { + resolver.resolve(()); + self.outgoing_puback(puback)? + } + Request::PubRec(pubrec, resolver) => { + resolver.resolve(()); + self.outgoing_pubrec(pubrec)? + } }; self.last_outgoing = Instant::now(); @@ -171,7 +190,7 @@ impl MqttState { ) -> Result, StateError> { self.events.push_back(Event::Incoming(packet.clone())); - let outgoing = match &packet { + let outgoing = match packet { Incoming::PingResp => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, @@ -190,26 +209,18 @@ impl MqttState { Ok(outgoing) } - fn is_pkid_of_publish(&self, pkid: u16) -> bool { - self.outgoing_pub[pkid as usize].is_some() || self.outgoing_rel.contains(pkid as usize) - } - - fn handle_incoming_suback(&mut self, suback: &SubAck) -> Result, StateError> { - // Expected ack for a subscribe packet, not a publish packet - if self.is_pkid_of_publish(suback.pkid) { + fn handle_incoming_suback(&mut self, suback: SubAck) -> Result, StateError> { + let Some(resolver) = self.sub_ack_waiter.remove(&suback.pkid) else { return Err(StateError::Unsolicited(suback.pkid)); - } - - if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { - if suback - .return_codes - .iter() - .all(|r| matches!(r, SubscribeReasonCode::Success(_))) - { - tx.resolve(suback.pkid); - } else { - tx.fail(format!("{:?}", suback.return_codes)); - } + }; + if suback + .return_codes + .iter() + .all(|r| matches!(r, SubscribeReasonCode::Success(_))) + { + resolver.resolve(suback.pkid); + } else { + resolver.reject(suback.return_codes); } Ok(None) @@ -217,23 +228,20 @@ impl MqttState { fn handle_incoming_unsuback( &mut self, - unsuback: &UnsubAck, + unsuback: UnsubAck, ) -> Result, StateError> { - // Expected ack for a unsubscribe packet, not a publish packet - if self.is_pkid_of_publish(unsuback.pkid) { + let Some(resolver) = self.unsub_ack_waiter.remove(&unsuback.pkid) else { return Err(StateError::Unsolicited(unsuback.pkid)); - } + }; - if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { - tx.resolve(unsuback.pkid); - } + resolver.resolve(unsuback.pkid); Ok(None) } /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &Publish) -> Result, StateError> { + fn handle_incoming_publish(&mut self, publish: Publish) -> Result, StateError> { let qos = publish.qos; match qos { @@ -258,7 +266,7 @@ impl MqttState { } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { + fn handle_incoming_puback(&mut self, puback: PubAck) -> Result, StateError> { let p = self .outgoing_pub .get_mut(puback.pkid as usize) @@ -271,28 +279,32 @@ impl MqttState { return Err(StateError::Unsolicited(puback.pkid)); } - if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { - // Resolve promise for QoS 1 - tx.resolve(puback.pkid); - } + let Some(resolver) = self.pub_ack_waiter.remove(&puback.pkid) else { + return Err(StateError::Unsolicited(puback.pkid)); + }; + + // Resolve promise for QoS 1 + resolver.resolve(puback.pkid); self.inflight -= 1; - let packet = self.check_collision(puback.pkid).map(|(publish, tx)| { - self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); - self.inflight += 1; + let packet = self + .check_collision(puback.pkid) + .map(|(publish, resolver)| { + self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + self.inflight += 1; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - self.ack_waiter[puback.pkid as usize] = tx; + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(publish.pkid, resolver); - Packet::Publish(publish) - }); + Packet::Publish(publish) + }); Ok(packet) } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { if self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -313,7 +325,7 @@ impl MqttState { Ok(Some(Packet::PubRel(pubrel))) } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { if !self.incoming_pub.contains(pubrel.pkid as usize) { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); return Err(StateError::Unsolicited(pubrel.pkid)); @@ -327,27 +339,31 @@ impl MqttState { Ok(Some(Packet::PubComp(pubcomp))) } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { + fn handle_incoming_pubcomp(&mut self, pubcomp: PubComp) -> Result, StateError> { if !self.outgoing_rel.contains(pubcomp.pkid as usize) { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); } - if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { - // Resolve promise for QoS 2 - tx.resolve(pubcomp.pkid); - } + let Some(resolver) = self.pub_ack_waiter.remove(&pubcomp.pkid) else { + return Err(StateError::Unsolicited(pubcomp.pkid)); + }; + + // Resolve promise for QoS 2 + resolver.resolve(pubcomp.pkid); self.outgoing_rel.set(pubcomp.pkid as usize, false); self.inflight -= 1; - let packet = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - self.ack_waiter[pubcomp.pkid as usize] = tx; + let packet = self + .check_collision(pubcomp.pkid) + .map(|(publish, resolver)| { + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(publish.pkid, resolver); - Packet::Publish(publish) - }); + Packet::Publish(publish) + }); Ok(packet) } @@ -363,7 +379,7 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { @@ -378,7 +394,7 @@ impl MqttState { .is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some((publish, tx)); + self.collision = Some((publish, resolver)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -399,20 +415,26 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); - match (publish.qos, tx) { - (QoS::AtMostOnce, Some(tx)) => tx.resolve(publish.pkid), - (_, tx) => self.ack_waiter[publish.pkid as usize] = tx, + if publish.qos == QoS::AtMostOnce { + resolver.resolve(publish.pkid); + } else { + self.pub_ack_waiter.insert(publish.pkid, resolver); } Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + resolver: Resolver, + ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); + self.pub_ack_waiter.insert(pubrel.pkid, resolver); Ok(Some(Packet::PubRel(pubrel))) } @@ -469,7 +491,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -485,7 +507,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); - self.ack_waiter[subscription.pkid as usize] = tx; + self.sub_ack_waiter.insert(subscription.pkid, resolver); Ok(Some(Packet::Subscribe(subscription))) } @@ -493,7 +515,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -505,7 +527,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); - self.ack_waiter[unsub.pkid as usize] = tx; + self.unsub_ack_waiter.insert(unsub.pkid, resolver); Ok(Some(Packet::Unsubscribe(unsub))) } @@ -519,7 +541,7 @@ impl MqttState { Ok(Some(Packet::Disconnect)) } - fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); @@ -568,7 +590,8 @@ impl MqttState { mod test { use super::{MqttState, StateError}; use crate::mqttbytes::v4::*; - use crate::mqttbytes::*; + use crate::tokens::Resolver; + use crate::{mqttbytes::*, Pkid}; use crate::{Event, Incoming, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { @@ -619,7 +642,8 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -627,12 +651,14 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -640,12 +666,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -659,9 +687,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); // only qos2 publish should be add to queue assert!(mqtt.incoming_pub.contains(3)); @@ -676,9 +704,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { assert_eq!(pkid, 2); @@ -703,9 +731,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); assert!(mqtt.incoming_pub.contains(3)); @@ -717,7 +745,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); + let packet = mqtt.handle_incoming_publish(publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), _ => panic!("Invalid network request: {:?}", packet), @@ -731,14 +759,16 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1, None).unwrap(); - mqtt.outgoing_publish(publish2, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish1, resolver).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish2, resolver).unwrap(); assert_eq!(mqtt.inflight, 2); - mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1)).unwrap(); assert_eq!(mqtt.inflight, 1); - mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2)).unwrap(); assert_eq!(mqtt.inflight, 0); assert!(mqtt.outgoing_pub[1].is_none()); @@ -749,7 +779,7 @@ mod test { fn incoming_puback_with_pkid_greater_than_max_inflight_should_be_handled_gracefully() { let mut mqtt = build_mqttstate(); - let got = mqtt.handle_incoming_puback(&PubAck::new(101)).unwrap_err(); + let got = mqtt.handle_incoming_puback(PubAck::new(101)).unwrap_err(); match got { StateError::Unsolicited(pkid) => assert_eq!(pkid, 101), @@ -764,10 +794,12 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1, None); - let _publish_out = mqtt.outgoing_publish(publish2, None); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish1, resolver); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish2, resolver); - mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(2)).unwrap(); assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 @@ -783,14 +815,15 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - let packet = mqtt.outgoing_publish(publish, None).unwrap().unwrap(); + let resolver = Resolver::mock(); + let packet = mqtt.outgoing_publish(publish, resolver).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } let packet = mqtt - .handle_incoming_pubrec(&PubRec::new(1)) + .handle_incoming_pubrec(PubRec::new(1)) .unwrap() .unwrap(); match packet { @@ -804,14 +837,14 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); + let packet = mqtt.handle_incoming_publish(publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } let packet = mqtt - .handle_incoming_pubrel(&PubRel::new(1)) + .handle_incoming_pubrel(PubRel::new(1)) .unwrap() .unwrap(); match packet { @@ -825,10 +858,11 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish, None).unwrap(); - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(1)).unwrap(); - mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + mqtt.handle_incoming_pubcomp(PubComp::new(1)).unwrap(); assert_eq!(mqtt.inflight, 0); } @@ -839,7 +873,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish), None) + let resolver = Resolver::mock(); + mqtt.handle_outgoing_packet(Request::Publish(publish, resolver)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) .unwrap(); @@ -868,8 +903,8 @@ mod test { fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); - fn build_outgoing_pub() -> Vec> { - vec![ + fn build_outgoing_pub(state: &mut MqttState) { + state.outgoing_pub = vec![ None, Some(Publish { dup: false, @@ -905,39 +940,47 @@ mod test { pkid: 6, payload: "".into(), }), - ] + ]; + for (i, _) in state + .outgoing_pub + .iter() + .enumerate() + .filter(|(_, p)| p.is_some()) + { + state.pub_ack_waiter.insert(i as Pkid, Resolver::mock()); + } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 3; let requests = mqtt.clean(); let res = vec![6, 1, 2, 3]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = &req.0 { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 0; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = &req.0 { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 6; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = &req.0 { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() diff --git a/rumqttc/src/tokens.rs b/rumqttc/src/tokens.rs new file mode 100644 index 000000000..26f307f0c --- /dev/null +++ b/rumqttc/src/tokens.rs @@ -0,0 +1,118 @@ +use std::{ + fmt::Debug, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::sync::oneshot::{self, error::TryRecvError}; + +pub trait Reason: Debug + Send {} +impl Reason for T where T: Debug + Send {} + +#[derive(Debug, thiserror::Error)] +#[error("Broker rejected the request, reason: {0:?}")] +pub struct Rejection(Box); + +impl Rejection { + fn new(reason: R) -> Self { + Self(Box::new(reason)) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum TokenError { + #[error("Sender has nothing to send instantly")] + Waiting, + #[error("Sender side of channel was dropped")] + Disconnected, + #[error("Broker rejected the request, reason: {0:?}")] + Rejection(#[from] Rejection), +} + +pub type NoResponse = (); + +/// Resolves with [`Pkid`] used against packet when: +/// 1. Packet is acknowldged by the broker, e.g. QoS 1/2 Publish, Subscribe and Unsubscribe +/// 2. QoS 0 packet finishes processing in the [`EventLoop`] +pub struct Token { + rx: oneshot::Receiver>, +} + +impl Future for Token { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let polled = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }.poll(cx); + + match polled { + Poll::Ready(Ok(Ok(p))) => Poll::Ready(Ok(p)), + Poll::Ready(Ok(Err(e))) => Poll::Ready(Err(TokenError::Rejection(e))), + Poll::Ready(Err(_)) => Poll::Ready(Err(TokenError::Disconnected)), + Poll::Pending => Poll::Pending, + } + } +} + +/// There is a type of token returned for each type of [`Request`] when it is created and +/// sent to the [`EventLoop`] for further processing from the [`Client`]/[`AsyncClient`]. +/// Some tokens such as those associated with the resolve with the `pkid` value used in the packet sent to the broker while other +/// tokens don't return such a value. +impl Token { + /// Blocks on the current thread and waits till the packet completes being handled. + /// + /// ## Errors + /// Returns [`TokenError::Disconnected`] if the [`EventLoop`] was dropped(usually), + /// [`TokenError::Rejection`] if the packet acknowledged but not accepted. + pub fn wait(self) -> Result { + self.rx + .blocking_recv() + .map_err(|_| TokenError::Disconnected)? + .map_err(|e| TokenError::Rejection(e)) + } + + /// Attempts to check if the packet handling has been completed, without blocking the current thread. + /// + /// ## Errors + /// Returns [`TokenError::Waiting`] if the packet wasn't acknowledged yet. + /// Multiple calls to this functions can fail with [`TokenError::Disconnected`] + /// if the promise has already been resolved. + pub fn check(&mut self) -> Result { + match self.rx.try_recv() { + Ok(r) => r.map_err(|e| TokenError::Rejection(e)), + Err(TryRecvError::Empty) => Err(TokenError::Waiting), + Err(TryRecvError::Closed) => Err(TokenError::Disconnected), + } + } +} + +#[derive(Debug)] +pub struct Resolver { + tx: oneshot::Sender>, +} + +impl Resolver { + pub fn new() -> (Self, Token) { + let (tx, rx) = oneshot::channel(); + + (Self { tx }, Token { rx }) + } + + #[cfg(test)] + pub fn mock() -> Self { + let (tx, _) = oneshot::channel(); + + Self { tx } + } + + pub fn resolve(self, resolved: T) { + if self.tx.send(Ok(resolved)).is_err() { + trace!("Promise was dropped") + } + } + + pub fn reject(self, reasons: R) { + if self.tx.send(Err(Rejection::new(reasons))).is_err() { + trace!("Promise was dropped") + } + } +} diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index a099f54d7..6f537f4d7 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -8,7 +8,8 @@ use super::mqttbytes::v5::{ }; use super::mqttbytes::QoS; use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; -use crate::{valid_filter, valid_topic, AckPromise, PromiseTx}; +use crate::tokens::{NoResponse, Resolver, Token}; +use crate::{valid_filter, valid_topic, Pkid}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -25,15 +26,15 @@ pub enum ClientError { TryRequest(Request), } -impl From)>> for ClientError { - fn from(e: SendError<(Request, Option)>) -> Self { - Self::Request(e.into_inner().0) +impl From> for ClientError { + fn from(e: SendError) -> Self { + Self::Request(e.into_inner()) } } -impl From)>> for ClientError { - fn from(e: TrySendError<(Request, Option)>) -> Self { - Self::TryRequest(e.into_inner().0) +impl From> for ClientError { + fn from(e: TrySendError) -> Self { + Self::TryRequest(e.into_inner()) } } @@ -46,7 +47,7 @@ impl From)>> for ClientError { /// from the broker, i.e. move ahead. #[derive(Clone, Debug)] pub struct AsyncClient { - request_tx: Sender<(Request, Option)>, + request_tx: Sender, } impl AsyncClient { @@ -66,7 +67,7 @@ impl AsyncClient { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_senders(request_tx: Sender<(Request, Option)>) -> AsyncClient { + pub fn from_senders(request_tx: Sender) -> AsyncClient { AsyncClient { request_tx } } @@ -78,24 +79,22 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.request_tx - .send_async((publish, Some(promise_tx))) - .await?; + self.request_tx.send_async(publish).await?; - Ok(promise) + Ok(token) } pub async fn publish_with_properties( @@ -105,7 +104,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -120,7 +119,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -136,22 +135,22 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } - self.request_tx.try_send((publish, Some(promise_tx)))?; + self.request_tx.try_send(publish)?; - Ok(promise) + Ok(token) } pub fn try_publish_with_properties( @@ -161,7 +160,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -175,7 +174,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -184,24 +183,26 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.request_tx.send_async((ack, None)).await?; + self.request_tx.send_async(ack).await?; } - Ok(()) + Ok(token) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.request_tx.try_send((ack, None))?; + self.request_tx.try_send(ack)?; } - Ok(()) + Ok(token) } /// Sends a MQTT Publish to the `EventLoop` @@ -212,20 +213,18 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: Option, - ) -> Result + ) -> Result, ClientError> where S: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); - self.request_tx - .send_async((publish, Some(promise_tx))) - .await?; + let publish = Request::Publish(publish, resolver); + self.request_tx.send_async(publish).await?; - Ok(promise) + Ok(token) } pub async fn publish_bytes_with_properties( @@ -235,7 +234,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, { @@ -249,7 +248,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result + ) -> Result, ClientError> where S: Into, { @@ -263,18 +262,18 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((subscribe.into(), Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } pub async fn subscribe_with_properties>( @@ -282,7 +281,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, Some(properties)).await } @@ -290,7 +289,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, None).await } @@ -300,17 +299,18 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((subscribe.into(), Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } pub fn try_subscribe_with_properties>( @@ -318,7 +318,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_try_subscribe(topic, qos, Some(properties)) } @@ -326,7 +326,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.handle_try_subscribe(topic, qos, None) } @@ -335,34 +335,34 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.request_tx - .send_async((subscribe.into(), Some(promise_tx))) - .await?; + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } pub async fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } - pub async fn subscribe_many(&self, topics: T) -> Result + pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -374,33 +374,34 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx - .try_send((subscribe.into(), Some(promise_tx)))?; + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } pub fn try_subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -412,26 +413,24 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); - self.request_tx - .send_async((request, Some(promise_tx))) - .await?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.request_tx.send_async(request).await?; - Ok(promise) + Ok(token) } pub async fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_unsubscribe(topic, Some(properties)).await } - pub async fn unsubscribe>(&self, topic: S) -> Result { + pub async fn unsubscribe>(&self, topic: S) -> Result, ClientError> { self.handle_unsubscribe(topic, None).await } @@ -440,49 +439,54 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); - self.request_tx.try_send((request, Some(promise_tx)))?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.request_tx.try_send(request)?; - Ok(promise) + Ok(token) } pub fn try_unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_try_unsubscribe(topic, Some(properties)) } - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { self.handle_try_unsubscribe(topic, None) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; - self.request_tx.send_async((request, None)).await?; + pub async fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.request_tx.send_async(request).await?; - Ok(()) + Ok(token) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; - self.request_tx.try_send((request, None))?; + pub fn try_disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.request_tx.try_send(request)?; - Ok(()) + Ok(token) } } -fn get_ack_req(publish: &Publish) -> Option { +fn get_ack_req(publish: &Publish, resolver: Resolver<()>) -> Option { let ack = match publish.qos { - QoS::AtMostOnce => return None, - QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid, None)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid, None)), + QoS::AtMostOnce => { + resolver.resolve(()); + return None; + } + QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid, None), resolver), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid, None), resolver), }; Some(ack) } @@ -523,7 +527,7 @@ impl Client { /// /// This is mostly useful for creating a test instance where you can /// listen on the corresponding receiver. - pub fn from_sender(request_tx: Sender<(Request, Option)>) -> Client { + pub fn from_sender(request_tx: Sender) -> Client { Client { client: AsyncClient::from_senders(request_tx), } @@ -537,22 +541,22 @@ impl Client { retain: bool, payload: P, properties: Option, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } - self.client.request_tx.send((publish, Some(promise_tx)))?; + self.client.request_tx.send(publish)?; - Ok(promise) + Ok(token) } pub fn publish_with_properties( @@ -562,7 +566,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -576,7 +580,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -591,7 +595,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -606,7 +610,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result + ) -> Result, ClientError> where S: Into, P: Into, @@ -615,18 +619,19 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { - self.client.request_tx.send((ack, None))?; + self.client.request_tx.send(ack)?; } - Ok(()) + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { self.client.try_ack(publish) } @@ -636,18 +641,18 @@ impl Client { topic: S, qos: QoS, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.client - .request_tx - .send((subscribe.into(), Some(promise_tx)))?; + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } pub fn subscribe_with_properties>( @@ -655,7 +660,7 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, Some(properties)) } @@ -663,7 +668,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, None) } @@ -673,7 +678,7 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.client .try_subscribe_with_properties(topic, qos, properties) } @@ -682,7 +687,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result { + ) -> Result, ClientError> { self.client.try_subscribe(topic, qos) } @@ -691,34 +696,34 @@ impl Client { &self, topics: T, properties: Option, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { - let (promise_tx, promise) = PromiseTx::new(); + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } - self.client - .request_tx - .send((subscribe.into(), Some(promise_tx)))?; + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } pub fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } - pub fn subscribe_many(&self, topics: T) -> Result + pub fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -729,7 +734,7 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result + ) -> Result, ClientError> where T: IntoIterator, { @@ -737,7 +742,7 @@ impl Client { .try_subscribe_many_with_properties(topics, properties) } - pub fn try_subscribe_many(&self, topics: T) -> Result + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -749,24 +754,24 @@ impl Client { &self, topic: S, properties: Option, - ) -> Result { - let (promise_tx, promise) = PromiseTx::new(); + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); - self.client.request_tx.send((request, Some(promise_tx)))?; + let request = Request::Unsubscribe(unsubscribe, resolver); + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } pub fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.handle_unsubscribe(topic, Some(properties)) } - pub fn unsubscribe>(&self, topic: S) -> Result { + pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { self.handle_unsubscribe(topic, None) } @@ -775,26 +780,26 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result { + ) -> Result, ClientError> { self.client .try_unsubscribe_with_properties(topic, properties) } - pub fn try_unsubscribe>(&self, topic: S) -> Result { + pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result { - let (promise_tx, promise) = PromiseTx::new(); - let request = Request::Disconnect; - self.client.request_tx.send((request, Some(promise_tx)))?; + pub fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); + self.client.request_tx.send(request)?; - Ok(promise) + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { + pub fn try_disconnect(&self) -> Result, ClientError> { self.client.try_disconnect() } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index b2c2fc506..ea361b4eb 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -3,7 +3,6 @@ use super::mqttbytes::v5::*; use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; use crate::eventloop::socket_connect; use crate::framed::AsyncReadWrite; -use crate::PromiseTx; use flume::{bounded, Receiver, Sender}; use tokio::select; @@ -74,11 +73,11 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver<(Request, Option)>, + requests_rx: Receiver, /// Requests handle to send requests - pub(crate) requests_tx: Sender<(Request, Option)>, + pub(crate) requests_tx: Sender, /// Pending packets from last session - pub pending: VecDeque<(Request, Option)>, + pub pending: VecDeque, /// Network connection to the broker network: Option, /// Keep alive time @@ -129,9 +128,9 @@ impl EventLoop { // drain requests from channel which weren't yet received let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect(); - requests_in_channel.retain(|(request, _)| { + requests_in_channel.retain(|request| { match request { - Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack + Request::PubAck(..) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, } }); @@ -224,8 +223,8 @@ impl EventLoop { &self.requests_rx, self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok((request, tx)) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request, tx)? { + Ok(request) => { + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { network.write(outgoing).await?; } network.flush().await?; @@ -246,7 +245,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq, None)? { + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { network.write(outgoing).await?; } network.flush().await?; @@ -256,10 +255,10 @@ impl EventLoop { } async fn next_request( - pending: &mut VecDeque<(Request, Option)>, - rx: &Receiver<(Request, Option)>, + pending: &mut VecDeque, + rx: &Receiver, pending_throttle: Duration, - ) -> Result<(Request, Option), ConnectionError> { + ) -> Result { if !pending.is_empty() { time::sleep(pending_throttle).await; // We must call .next() AFTER sleep() otherwise .next() would diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 6e0e43931..22b1942c2 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -14,8 +14,9 @@ mod framed; pub mod mqttbytes; mod state; -use crate::Outgoing; +use crate::tokens::Resolver; use crate::{NetworkOptions, Transport}; +use crate::{Outgoing, Pkid}; use mqttbytes::v5::*; @@ -33,26 +34,16 @@ pub type Incoming = Packet; /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug)] pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), + Publish(Publish, Resolver), + PubAck(PubAck, Resolver<()>), + PubRec(PubRec, Resolver<()>), + PubRel(PubRel, Resolver), + Subscribe(Subscribe, Resolver), + Unsubscribe(Unsubscribe, Resolver), + Disconnect(Resolver<()>), PingReq, - PingResp, - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect, -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Self { - Self::Subscribe(subscribe) - } } #[cfg(feature = "websocket")] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 97a6def30..6f9e430c3 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,13 +1,17 @@ -use crate::PromiseTx; - -use super::mqttbytes::v5::{ - ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, - PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, - SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, +use crate::{tokens::Resolver, Pkid}; + +use super::{ + mqttbytes::{ + self, + v5::{ + ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, + PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, + Publish, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, + }, + Error as MqttError, QoS, + }, + Event, Incoming, Outgoing, Request, }; -use super::mqttbytes::{self, Error as MqttError, QoS}; - -use super::{Event, Incoming, Outgoing, Request}; use bytes::Bytes; use fixedbitset::FixedBitSet; @@ -99,7 +103,7 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option<(Publish, Option)>, + pub collision: Option<(Publish, Resolver)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -112,8 +116,12 @@ pub struct MqttState { pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, - /// Waiters for publish/subscribe/unsubscribe acknowledgements - ack_waiter: Vec>, + /// Waiters for publish acknowledgements + pub_ack_waiter: HashMap>, + /// Waiters for subscribe acknowledgements + sub_ack_waiter: HashMap>, + /// Waiters for unsubscribe acknowledgements + unsub_ack_waiter: HashMap>, } impl MqttState { @@ -141,27 +149,29 @@ impl MqttState { broker_topic_alias_max: 0, max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, - ack_waiter: (0..max_inflight as usize + 1).map(|_| None).collect(), + pub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + sub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + unsub_ack_waiter: HashMap::with_capacity(max_inflight as usize), } } /// Returns inflight outgoing packets and clears internal queues - pub fn clean(&mut self) -> Vec<(Request, Option)> { + pub fn clean(&mut self) -> Vec { let mut pending = Vec::with_capacity(100); // remove and collect pending publishes for publish in self.outgoing_pub.iter_mut() { if let Some(publish) = publish.take() { - let tx = self.ack_waiter[publish.pkid as usize].take(); - let request = Request::Publish(publish); - pending.push((request, tx)); + let resolver = self.pub_ack_waiter.remove(&publish.pkid).unwrap(); + let request = Request::Publish(publish, resolver); + pending.push(request); } } // remove and collect pending releases for pkid in self.outgoing_rel.ones() { - let tx = self.ack_waiter[pkid].take(); - let request = Request::PubRel(PubRel::new(pkid as u16, None)); - pending.push((request, tx)); + let resolver = self.pub_ack_waiter.remove(&(pkid as u16)).unwrap(); + let request = Request::PubRel(PubRel::new(pkid as u16, None), resolver); + pending.push(request); } self.outgoing_rel.clear(); @@ -183,20 +193,29 @@ impl MqttState { pub fn handle_outgoing_packet( &mut self, request: Request, - tx: Option, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish, tx)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, tx)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, tx)?, + Request::Publish(publish, resolver) => self.outgoing_publish(publish, resolver)?, + Request::PubRel(pubrel, resolver) => self.outgoing_pubrel(pubrel, resolver)?, + Request::Subscribe(subscribe, resolver) => { + self.outgoing_subscribe(subscribe, resolver)? + } + Request::Unsubscribe(unsubscribe, resolver) => { + self.outgoing_unsubscribe(unsubscribe, resolver)? + } Request::PingReq => self.outgoing_ping()?, - Request::Disconnect => { + Request::Disconnect(resolver) => { + resolver.resolve(()); self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? } - Request::PubAck(puback) => self.outgoing_puback(puback)?, - Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, - _ => unimplemented!(), + Request::PubAck(puback, resolver) => { + resolver.resolve(()); + self.outgoing_puback(puback)? + } + Request::PubRec(pubrec, resolver) => { + resolver.resolve(()); + self.outgoing_pubrec(pubrec)? + } }; self.last_outgoing = Instant::now(); @@ -209,11 +228,11 @@ impl MqttState { /// be forwarded to user and Pubck packet will be written to network pub fn handle_incoming_packet( &mut self, - mut packet: Incoming, + packet: Incoming, ) -> Result, StateError> { self.events.push_back(Event::Incoming(packet.to_owned())); - let outgoing = match &mut packet { + let outgoing = match packet { Incoming::PingResp(_) => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, @@ -239,18 +258,11 @@ impl MqttState { self.outgoing_disconnect(DisconnectReasonCode::ProtocolError) } - fn is_pkid_of_publish(&self, pkid: u16) -> bool { - self.outgoing_pub[pkid as usize].is_some() || self.outgoing_rel.contains(pkid as usize) - } - - fn handle_incoming_suback( - &mut self, - suback: &mut SubAck, - ) -> Result, StateError> { + fn handle_incoming_suback(&mut self, suback: SubAck) -> Result, StateError> { // Expected ack for a subscribe packet, not a publish packet - if self.is_pkid_of_publish(suback.pkid) { + let Some(resolver) = self.sub_ack_waiter.remove(&suback.pkid) else { return Err(StateError::Unsolicited(suback.pkid)); - } + }; for reason in suback.return_codes.iter() { match reason { @@ -263,16 +275,14 @@ impl MqttState { } } - if let Some(tx) = self.ack_waiter[suback.pkid as usize].take() { - if suback - .return_codes - .iter() - .all(|r| matches!(r, SubscribeReasonCode::Success(_))) - { - tx.resolve(suback.pkid); - } else { - tx.fail(format!("{:?}", suback.return_codes)); - } + if suback + .return_codes + .iter() + .all(|r| matches!(r, SubscribeReasonCode::Success(_))) + { + resolver.resolve(suback.pkid); + } else { + resolver.reject(suback.return_codes); } Ok(None) @@ -280,12 +290,11 @@ impl MqttState { fn handle_incoming_unsuback( &mut self, - unsuback: &mut UnsubAck, + unsuback: UnsubAck, ) -> Result, StateError> { - // Expected ack for a unsubscribe packet, not a publish packet - if self.is_pkid_of_publish(unsuback.pkid) { + let Some(resolver) = self.unsub_ack_waiter.remove(&unsuback.pkid) else { return Err(StateError::Unsolicited(unsuback.pkid)); - } + }; for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { @@ -293,25 +302,20 @@ impl MqttState { } } - if let Some(tx) = self.ack_waiter[unsuback.pkid as usize].take() { - if unsuback - .reasons - .iter() - .all(|r| matches!(r, UnsubAckReason::Success)) - { - tx.resolve(unsuback.pkid); - } else { - tx.fail(format!("{:?}", unsuback.reasons)); - } + if unsuback + .reasons + .iter() + .all(|r| matches!(r, UnsubAckReason::Success)) + { + resolver.resolve(unsuback.pkid); + } else { + resolver.reject(unsuback.reasons); } Ok(None) } - fn handle_incoming_connack( - &mut self, - connack: &mut ConnAck, - ) -> Result, StateError> { + fn handle_incoming_connack(&mut self, connack: ConnAck) -> Result, StateError> { if connack.code != ConnectReturnCode::Success { return Err(StateError::ConnFail { reason: connack.code, @@ -335,7 +339,7 @@ impl MqttState { fn handle_incoming_disconn( &mut self, - disconn: &mut Disconnect, + disconn: Disconnect, ) -> Result, StateError> { let reason_code = disconn.reason_code; let reason_string = if let Some(props) = &disconn.properties { @@ -353,7 +357,7 @@ impl MqttState { /// in case of QoS1 and Replys rec in case of QoS while also storing the message fn handle_incoming_publish( &mut self, - publish: &mut Publish, + mut publish: Publish, ) -> Result, StateError> { let qos = publish.qos; @@ -396,24 +400,22 @@ impl MqttState { } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; - - if publish.take().is_none() { + fn handle_incoming_puback(&mut self, puback: PubAck) -> Result, StateError> { + let Some(resolver) = self.pub_ack_waiter.remove(&puback.pkid) else { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); - } + }; - if let Some(tx) = self.ack_waiter[puback.pkid as usize].take() { - // Resolve promise for QoS 1 - if puback.reason == PubAckReason::Success { - tx.resolve(puback.pkid); - } else { - tx.fail(format!("{:?}", puback.reason)); - } + self.outgoing_pub + .get_mut(puback.pkid as usize) + .ok_or(StateError::Unsolicited(puback.pkid))? + .take(); + + // Resolve promise for QoS 1 + if puback.reason == PubAckReason::Success { + resolver.resolve(puback.pkid); + } else { + resolver.reject(puback.reason); } self.inflight -= 1; @@ -428,7 +430,7 @@ impl MqttState { return Ok(None); } - if let Some((publish, tx)) = self.check_collision(puback.pkid) { + if let Some((publish, resolver)) = self.check_collision(puback.pkid) { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; @@ -436,7 +438,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; - self.ack_waiter[puback.pkid as usize] = tx; + self.pub_ack_waiter.insert(puback.pkid, resolver); return Ok(Some(Packet::Publish(publish))); } @@ -444,7 +446,7 @@ impl MqttState { Ok(None) } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -473,7 +475,7 @@ impl MqttState { Ok(Some(Packet::PubRel(PubRel::new(pubrec.pkid, None)))) } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { if !self.incoming_pub.contains(pubrel.pkid as usize) { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); return Err(StateError::Unsolicited(pubrel.pkid)); @@ -494,30 +496,30 @@ impl MqttState { Ok(Some(Packet::PubComp(PubComp::new(pubrel.pkid, None)))) } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { + fn handle_incoming_pubcomp(&mut self, pubcomp: PubComp) -> Result, StateError> { + let Some(resolver) = self.pub_ack_waiter.remove(&pubcomp.pkid) else { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); - } + }; - if let Some(tx) = self.ack_waiter[pubcomp.pkid as usize].take() { - // Resolve promise for QoS 2 - if pubcomp.reason == PubCompReason::Success { - tx.resolve(pubcomp.pkid); - } else { - tx.fail(format!("{:?}", pubcomp.reason)); - } + // Resolve promise for QoS 2 + if pubcomp.reason == PubCompReason::Success { + resolver.resolve(pubcomp.pkid); + } else { + resolver.reject(pubcomp.reason); } self.outgoing_rel.set(pubcomp.pkid as usize, false); - let outgoing = self.check_collision(pubcomp.pkid).map(|(publish, tx)| { - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - self.ack_waiter[pubcomp.pkid as usize] = tx; + let outgoing = self + .check_collision(pubcomp.pkid) + .map(|(publish, resolver)| { + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(pubcomp.pkid, resolver); - Packet::Publish(publish) - }); + Packet::Publish(publish) + }); if pubcomp.reason != PubCompReason::Success { warn!( @@ -541,7 +543,7 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { @@ -556,7 +558,7 @@ impl MqttState { .is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some((publish, tx)); + self.collision = Some((publish, resolver)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -592,21 +594,27 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); - match (publish.qos, tx) { - (QoS::AtMostOnce, Some(tx)) => tx.resolve(0), - (_, tx) => self.ack_waiter[publish.pkid as usize] = tx, + if publish.qos == QoS::AtMostOnce { + resolver.resolve(0); + } else { + self.pub_ack_waiter.insert(publish.pkid, resolver); } Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + resolver: Resolver, + ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); + self.pub_ack_waiter.insert(pubrel.pkid, resolver); Ok(Some(Packet::PubRel(PubRel::new(pubrel.pkid, None)))) } @@ -662,7 +670,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -679,7 +687,7 @@ impl MqttState { let pkid = subscription.pkid; let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); - self.ack_waiter[subscription.pkid as usize] = tx; + self.sub_ack_waiter.insert(subscription.pkid, resolver); Ok(Some(Packet::Subscribe(subscription))) } @@ -687,7 +695,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - tx: Option, + resolver: Resolver, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -700,7 +708,7 @@ impl MqttState { let pkid = unsub.pkid; let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); - self.ack_waiter[unsub.pkid as usize] = tx; + self.unsub_ack_waiter.insert(unsub.pkid, resolver); Ok(Some(Packet::Unsubscribe(unsub))) } @@ -716,7 +724,7 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Option)> { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); @@ -763,6 +771,8 @@ impl MqttState { #[cfg(test)] mod test { + use crate::tokens::Resolver; + use super::mqttbytes::v5::*; use super::mqttbytes::*; use super::{Event, Incoming, Outgoing, Request}; @@ -816,7 +826,9 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -824,12 +836,15 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -837,12 +852,15 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -854,27 +872,31 @@ mod test { // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be set back down to 0, since we hit the limit - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); // This should cause a collition - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 2); assert!(mqtt.collision.is_some()); - mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); - mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 1); // Now there should be space in the outgoing queue - mqtt.outgoing_publish(publish.clone(), None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); } @@ -884,13 +906,13 @@ mod test { let mut mqtt = build_mqttstate(); // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); // only qos2 publish should be add to queue assert!(mqtt.incoming_pub.contains(3)); @@ -901,13 +923,13 @@ mod test { let mut mqtt = build_mqttstate(); // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { assert_eq!(pkid, 2); @@ -928,13 +950,13 @@ mod test { mqtt.manual_acks = true; // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); assert!(mqtt.incoming_pub.contains(3)); assert!(mqtt.events.is_empty()); @@ -943,9 +965,9 @@ mod test { #[test] fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { let mut mqtt = build_mqttstate(); - let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() { + match mqtt.handle_incoming_publish(publish).unwrap().unwrap() { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } @@ -958,14 +980,16 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1, None).unwrap(); - mqtt.outgoing_publish(publish2, None).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish1, resolver).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish2, resolver).unwrap(); assert_eq!(mqtt.inflight, 2); - mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1, None)).unwrap(); assert_eq!(mqtt.inflight, 1); - mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 0); assert!(mqtt.outgoing_pub[1].is_none()); @@ -977,7 +1001,7 @@ mod test { let mut mqtt = build_mqttstate(); let got = mqtt - .handle_incoming_puback(&PubAck::new(101, None)) + .handle_incoming_puback(PubAck::new(101, None)) .unwrap_err(); match got { @@ -993,10 +1017,12 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1, None); - let _publish_out = mqtt.outgoing_publish(publish2, None); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish1, resolver); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish2, resolver); - mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 @@ -1012,13 +1038,14 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - match mqtt.outgoing_publish(publish, None).unwrap().unwrap() { + let resolver = Resolver::mock(); + match mqtt.outgoing_publish(publish, resolver).unwrap().unwrap() { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } match mqtt - .handle_incoming_pubrec(&PubRec::new(1, None)) + .handle_incoming_pubrec(PubRec::new(1, None)) .unwrap() .unwrap() { @@ -1030,15 +1057,15 @@ mod test { #[test] fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { let mut mqtt = build_mqttstate(); - let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() { + match mqtt.handle_incoming_publish(publish).unwrap().unwrap() { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } match mqtt - .handle_incoming_pubrel(&PubRel::new(1, None)) + .handle_incoming_pubrel(PubRel::new(1, None)) .unwrap() .unwrap() { @@ -1052,11 +1079,11 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish, None).unwrap(); - mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(1, None)).unwrap(); - mqtt.handle_incoming_pubcomp(&PubComp::new(1, None)) - .unwrap(); + mqtt.handle_incoming_pubcomp(PubComp::new(1, None)).unwrap(); assert_eq!(mqtt.inflight, 0); } @@ -1067,7 +1094,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish), None) + let resolver = Resolver::mock(); + mqtt.handle_outgoing_packet(Request::Publish(publish, resolver)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None))) .unwrap(); From d67e919628e98534853688b88d0cf1665b16cee5 Mon Sep 17 00:00:00 2001 From: swanandx <73115739+swanandx@users.noreply.github.com> Date: Fri, 15 Nov 2024 23:36:17 +0530 Subject: [PATCH 27/30] fix: imports --- rumqttc/src/lib.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index b707782db..f669508e2 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -119,7 +119,10 @@ mod tls; mod websockets; #[cfg(feature = "websocket")] -use std::future::IntoFuture; +use std::{ + future::{Future, IntoFuture}, + pin::Pin, +}; #[cfg(feature = "websocket")] type RequestModifierFn = Arc< From 8be8afe3411d520ac56905d2dd9431acf1374035 Mon Sep 17 00:00:00 2001 From: swanandx <73115739+swanandx@users.noreply.github.com> Date: Fri, 15 Nov 2024 23:38:33 +0530 Subject: [PATCH 28/30] fix: clippy lint --- rumqttc/src/tokens.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rumqttc/src/tokens.rs b/rumqttc/src/tokens.rs index 26f307f0c..4d30cac52 100644 --- a/rumqttc/src/tokens.rs +++ b/rumqttc/src/tokens.rs @@ -67,7 +67,7 @@ impl Token { self.rx .blocking_recv() .map_err(|_| TokenError::Disconnected)? - .map_err(|e| TokenError::Rejection(e)) + .map_err(TokenError::Rejection) } /// Attempts to check if the packet handling has been completed, without blocking the current thread. @@ -78,7 +78,7 @@ impl Token { /// if the promise has already been resolved. pub fn check(&mut self) -> Result { match self.rx.try_recv() { - Ok(r) => r.map_err(|e| TokenError::Rejection(e)), + Ok(r) => r.map_err(TokenError::Rejection), Err(TryRecvError::Empty) => Err(TokenError::Waiting), Err(TryRecvError::Closed) => Err(TokenError::Disconnected), } From 919c981fb30e22c0873e8e25d4ad0914ad11c326 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sun, 23 Feb 2025 04:17:10 +0530 Subject: [PATCH 29/30] fix: token interfaces (#946) --- rumqttc/examples/ack_promise.rs | 8 +- rumqttc/examples/ack_promise_sync.rs | 8 +- rumqttc/examples/ack_promise_v5.rs | 8 +- rumqttc/src/client.rs | 58 ++++++----- rumqttc/src/lib.rs | 27 ++++-- rumqttc/src/state.rs | 139 ++++++++++++++------------- rumqttc/src/tokens.rs | 42 ++------ rumqttc/src/v5/client.rs | 123 +++++++++++++----------- rumqttc/src/v5/mod.rs | 30 ++++-- rumqttc/src/v5/state.rs | 79 ++++++--------- rumqttc/tests/reliability.rs | 11 ++- 11 files changed, 272 insertions(+), 261 deletions(-) diff --git a/rumqttc/examples/ack_promise.rs b/rumqttc/examples/ack_promise.rs index 55ff7493e..f4c837b29 100644 --- a/rumqttc/examples/ack_promise.rs +++ b/rumqttc/examples/ack_promise.rs @@ -31,7 +31,7 @@ async fn main() -> Result<(), Box> { .unwrap() .await { - Ok(pkid) => println!("Acknowledged Sub({pkid})"), + Ok(pkid) => println!("Acknowledged Sub({pkid:?})"), Err(e) => println!("Subscription failed: {e:?}"), } @@ -46,7 +46,7 @@ async fn main() -> Result<(), Box> { .unwrap() .await { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Ok(ack) => println!("Acknowledged Pub({ack:?})"), Err(e) => println!("Publish failed: {e:?}"), } } @@ -66,14 +66,14 @@ async fn main() -> Result<(), Box> { while let Some(Ok(res)) = set.join_next().await { match res { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Ok(ack) => println!("Acknowledged Pub({ack:?})"), Err(e) => println!("Publish failed: {e:?}"), } } // Unsubscribe and wait for broker acknowledgement match client.unsubscribe("hello/world").await.unwrap().await { - Ok(pkid) => println!("Acknowledged Unsub({pkid})"), + Ok(ack) => println!("Acknowledged Unsub({ack:?})"), Err(e) => println!("Unsubscription failed: {e:?}"), } diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs index 506ee486c..1eaa20e0d 100644 --- a/rumqttc/examples/ack_promise_sync.rs +++ b/rumqttc/examples/ack_promise_sync.rs @@ -28,7 +28,7 @@ fn main() -> Result<(), Box> { .unwrap() .wait() { - Ok(pkid) => println!("Acknowledged Sub({pkid})"), + Ok(pkid) => println!("Acknowledged Sub({pkid:?})"), Err(e) => println!("Subscription failed: {e:?}"), } @@ -42,7 +42,7 @@ fn main() -> Result<(), Box> { .unwrap() .wait() { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Ok(ack) => println!("Acknowledged Pub({ack:?})"), Err(e) => println!("Publish failed: {e:?}"), } } @@ -83,14 +83,14 @@ fn main() -> Result<(), Box> { while let Ok(res) = rx.recv() { match res { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Ok(ack) => println!("Acknowledged Pub({ack:?})"), Err(e) => println!("Publish failed: {e:?}"), } } // Unsubscribe and wait for broker acknowledgement match client.unsubscribe("hello/world").unwrap().wait() { - Ok(pkid) => println!("Acknowledged Unsub({pkid})"), + Ok(ack) => println!("Acknowledged Unsub({ack:?})"), Err(e) => println!("Unsubscription failed: {e:?}"), } diff --git a/rumqttc/examples/ack_promise_v5.rs b/rumqttc/examples/ack_promise_v5.rs index 8873cf6af..de2fdf566 100644 --- a/rumqttc/examples/ack_promise_v5.rs +++ b/rumqttc/examples/ack_promise_v5.rs @@ -31,7 +31,7 @@ async fn main() -> Result<(), Box> { .unwrap() .await { - Ok(pkid) => println!("Acknowledged Sub({pkid})"), + Ok(pkid) => println!("Acknowledged Sub({pkid:?})"), Err(e) => println!("Subscription failed: {e:?}"), } @@ -46,7 +46,7 @@ async fn main() -> Result<(), Box> { .unwrap() .await { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Ok(pkid) => println!("Acknowledged Pub({pkid:?})"), Err(e) => println!("Publish failed: {e:?}"), } } @@ -66,14 +66,14 @@ async fn main() -> Result<(), Box> { while let Some(Ok(res)) = set.join_next().await { match res { - Ok(pkid) => println!("Acknowledged Pub({pkid})"), + Ok(pkid) => println!("Acknowledged Pub({pkid:?})"), Err(e) => println!("Publish failed: {e:?}"), } } // Unsubscribe and wait for broker acknowledgement match client.unsubscribe("hello/world").await.unwrap().await { - Ok(pkid) => println!("Acknowledged Unsub({pkid})"), + Ok(pkid) => println!("Acknowledged Unsub({pkid:?})"), Err(e) => println!("Unsubscription failed: {e:?}"), } diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index ffab94f0c..404b5230f 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -5,7 +5,8 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; use crate::tokens::{NoResponse, Resolver, Token}; use crate::{ - valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Pkid, Request, + valid_filter, valid_topic, AckOfAck, AckOfPub, ConnectionError, Event, EventLoop, MqttOptions, + Request, }; use bytes::Bytes; @@ -75,7 +76,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, V: Into>, @@ -100,7 +101,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, V: Into>, @@ -119,7 +120,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { + pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { @@ -130,7 +131,7 @@ impl AsyncClient { } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { @@ -147,7 +148,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, { @@ -165,7 +166,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); @@ -184,7 +185,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); let is_valid = subscribe_has_valid_filters(&subscribe); @@ -198,7 +199,7 @@ impl AsyncClient { } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -216,7 +217,7 @@ impl AsyncClient { } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -233,7 +234,10 @@ impl AsyncClient { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub async fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + pub async fn unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe, resolver); @@ -243,7 +247,10 @@ impl AsyncClient { } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { + pub fn try_unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe, resolver); @@ -271,10 +278,10 @@ impl AsyncClient { } } -fn get_ack_req(publish: &Publish, resolver: Resolver<()>) -> Option { +fn get_ack_req(publish: &Publish, resolver: Resolver) -> Option { let ack = match publish.qos { QoS::AtMostOnce => { - resolver.resolve(()); + resolver.resolve(AckOfAck::None); return None; } QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid), resolver), @@ -331,7 +338,7 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, V: Into>, @@ -355,7 +362,7 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, V: Into>, @@ -364,7 +371,7 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result, ClientError> { + pub fn ack(&self, publish: &Publish) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { @@ -375,7 +382,7 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { self.client.try_ack(publish) } @@ -384,7 +391,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); let is_valid = subscribe_has_valid_filters(&subscribe); @@ -402,12 +409,12 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn subscribe_many(&self, topics: T) -> Result, ClientError> + pub fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -424,7 +431,7 @@ impl Client { Ok(token) } - pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -432,7 +439,7 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); let request = Request::Unsubscribe(unsubscribe, resolver); @@ -442,7 +449,10 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { + pub fn try_unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { self.client.try_unsubscribe(topic) } diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index f669508e2..8cf870c80 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -160,6 +160,21 @@ pub use proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; +/// Used to encapsulate all publish/pubrec acknowledgements in v4 +#[derive(Debug, PartialEq)] +pub enum AckOfPub { + PubAck(PubAck), + PubComp(PubComp), + None, +} + +/// Used to encapsulate all ack/pubrel acknowledgements in v4 +#[derive(Debug)] +pub enum AckOfAck { + None, + PubRel(PubRel), +} + /// Current outgoing activity on the eventloop #[derive(Debug, Clone, PartialEq, Eq)] pub enum Outgoing { @@ -191,12 +206,12 @@ pub enum Outgoing { /// handled one by one. #[derive(Debug)] pub enum Request { - Publish(Publish, Resolver), - PubAck(PubAck, Resolver<()>), - PubRec(PubRec, Resolver<()>), - PubRel(PubRel, Resolver), - Subscribe(Subscribe, Resolver), - Unsubscribe(Unsubscribe, Resolver), + Publish(Publish, Resolver), + PubAck(PubAck, Resolver), + PubRec(PubRec, Resolver), + PubRel(PubRel, Resolver), + Subscribe(Subscribe, Resolver), + Unsubscribe(Unsubscribe, Resolver), Disconnect(Resolver<()>), PingReq, } diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 4698996a7..b1f918d09 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,5 +1,5 @@ -use crate::Pkid; use crate::{tokens::Resolver, Event, Incoming, Outgoing, Request}; +use crate::{AckOfAck, AckOfPub, Pkid}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; @@ -68,17 +68,19 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option<(Publish, Resolver)>, + pub collision: Option<(Publish, Resolver)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, /// Waiters for publish acknowledgements - pub_ack_waiter: HashMap>, + pub_ack_waiter: HashMap>, + /// Waiters for PubRel, qos 2 + pub_rel_waiter: HashMap>, /// Waiters for subscribe acknowledgements - sub_ack_waiter: HashMap>, + sub_ack_waiter: HashMap>, /// Waiters for unsubscribe acknowledgements - unsub_ack_waiter: HashMap>, + unsub_ack_waiter: HashMap>, } impl MqttState { @@ -104,6 +106,7 @@ impl MqttState { events: VecDeque::with_capacity(100), manual_acks, pub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + pub_rel_waiter: HashMap::with_capacity(max_inflight as usize), sub_ack_waiter: HashMap::with_capacity(max_inflight as usize), unsub_ack_waiter: HashMap::with_capacity(max_inflight as usize), } @@ -167,13 +170,10 @@ impl MqttState { self.outgoing_disconnect()? } Request::PubAck(puback, resolver) => { - resolver.resolve(()); + resolver.resolve(AckOfAck::None); self.outgoing_puback(puback)? } - Request::PubRec(pubrec, resolver) => { - resolver.resolve(()); - self.outgoing_pubrec(pubrec)? - } + Request::PubRec(pubrec, resolver) => self.outgoing_pubrec(pubrec, resolver)?, }; self.last_outgoing = Instant::now(); @@ -213,15 +213,8 @@ impl MqttState { let Some(resolver) = self.sub_ack_waiter.remove(&suback.pkid) else { return Err(StateError::Unsolicited(suback.pkid)); }; - if suback - .return_codes - .iter() - .all(|r| matches!(r, SubscribeReasonCode::Success(_))) - { - resolver.resolve(suback.pkid); - } else { - resolver.reject(suback.return_codes); - } + + resolver.resolve(suback); Ok(None) } @@ -234,7 +227,7 @@ impl MqttState { return Err(StateError::Unsolicited(unsuback.pkid)); }; - resolver.resolve(unsuback.pkid); + resolver.resolve(unsuback); Ok(None) } @@ -259,7 +252,8 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid); - return self.outgoing_pubrec(pubrec); + let (resolver, _) = Resolver::new(); + return self.outgoing_pubrec(pubrec, resolver); } Ok(None) } @@ -267,39 +261,38 @@ impl MqttState { } fn handle_incoming_puback(&mut self, puback: PubAck) -> Result, StateError> { + let pkid = puback.pkid; let p = self .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; + .get_mut(pkid as usize) + .ok_or(StateError::Unsolicited(pkid))?; - self.last_puback = puback.pkid; + self.last_puback = pkid; if p.take().is_none() { - error!("Unsolicited puback packet: {:?}", puback.pkid); - return Err(StateError::Unsolicited(puback.pkid)); + error!("Unsolicited puback packet: {pkid:?}"); + return Err(StateError::Unsolicited(pkid)); } - let Some(resolver) = self.pub_ack_waiter.remove(&puback.pkid) else { - return Err(StateError::Unsolicited(puback.pkid)); + let Some(resolver) = self.pub_ack_waiter.remove(&pkid) else { + return Err(StateError::Unsolicited(pkid)); }; // Resolve promise for QoS 1 - resolver.resolve(puback.pkid); + resolver.resolve(AckOfPub::PubAck(puback)); self.inflight -= 1; - let packet = self - .check_collision(puback.pkid) - .map(|(publish, resolver)| { - self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); - self.inflight += 1; + let packet = self.check_collision(pkid).map(|(publish, resolver)| { + self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + self.inflight += 1; - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - self.pub_ack_waiter.insert(publish.pkid, resolver); + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(publish.pkid, resolver); - Packet::Publish(publish) - }); + Packet::Publish(publish) + }); Ok(packet) } @@ -326,44 +319,47 @@ impl MqttState { } fn handle_incoming_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { - if !self.incoming_pub.contains(pubrel.pkid as usize) { - error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); - return Err(StateError::Unsolicited(pubrel.pkid)); + let pkid = pubrel.pkid; + if !self.incoming_pub.contains(pkid as usize) { + error!("Unsolicited pubrel packet: {:?}", pkid); + return Err(StateError::Unsolicited(pkid)); } - self.incoming_pub.set(pubrel.pkid as usize, false); - let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); - let pubcomp = PubComp { pkid: pubrel.pkid }; + let resolver = self.pub_rel_waiter.remove(&pkid).unwrap(); + resolver.resolve(AckOfAck::PubRel(pubrel)); + + self.incoming_pub.set(pkid as usize, false); + let event = Event::Outgoing(Outgoing::PubComp(pkid)); + let pubcomp = PubComp { pkid }; self.events.push_back(event); Ok(Some(Packet::PubComp(pubcomp))) } fn handle_incoming_pubcomp(&mut self, pubcomp: PubComp) -> Result, StateError> { - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { - error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); - return Err(StateError::Unsolicited(pubcomp.pkid)); + let pkid = pubcomp.pkid; + if !self.outgoing_rel.contains(pkid as usize) { + error!("Unsolicited pubcomp packet: {pkid:?}"); + return Err(StateError::Unsolicited(pkid)); } - let Some(resolver) = self.pub_ack_waiter.remove(&pubcomp.pkid) else { - return Err(StateError::Unsolicited(pubcomp.pkid)); + let Some(resolver) = self.pub_ack_waiter.remove(&pkid) else { + return Err(StateError::Unsolicited(pkid)); }; // Resolve promise for QoS 2 - resolver.resolve(pubcomp.pkid); + resolver.resolve(AckOfPub::PubComp(pubcomp)); - self.outgoing_rel.set(pubcomp.pkid as usize, false); + self.outgoing_rel.set(pkid as usize, false); self.inflight -= 1; - let packet = self - .check_collision(pubcomp.pkid) - .map(|(publish, resolver)| { - let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - self.pub_ack_waiter.insert(publish.pkid, resolver); + let packet = self.check_collision(pkid).map(|(publish, resolver)| { + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(publish.pkid, resolver); - Packet::Publish(publish) - }); + Packet::Publish(publish) + }); Ok(packet) } @@ -379,7 +375,7 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - resolver: Resolver, + resolver: Resolver, ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { @@ -416,7 +412,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); if publish.qos == QoS::AtMostOnce { - resolver.resolve(publish.pkid); + resolver.resolve(AckOfPub::None); } else { self.pub_ack_waiter.insert(publish.pkid, resolver); } @@ -427,7 +423,7 @@ impl MqttState { fn outgoing_pubrel( &mut self, pubrel: PubRel, - resolver: Resolver, + resolver: Resolver, ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; @@ -446,9 +442,14 @@ impl MqttState { Ok(Some(Packet::PubAck(puback))) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { + fn outgoing_pubrec( + &mut self, + pubrec: PubRec, + resolver: Resolver, + ) -> Result, StateError> { let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); self.events.push_back(event); + self.pub_rel_waiter.insert(pubrec.pkid, resolver); Ok(Some(Packet::PubRec(pubrec))) } @@ -491,7 +492,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - resolver: Resolver, + resolver: Resolver, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -515,7 +516,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - resolver: Resolver, + resolver: Resolver, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -541,7 +542,7 @@ impl MqttState { Ok(Some(Packet::Disconnect)) } - fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); diff --git a/rumqttc/src/tokens.rs b/rumqttc/src/tokens.rs index 4d30cac52..040fedb4e 100644 --- a/rumqttc/src/tokens.rs +++ b/rumqttc/src/tokens.rs @@ -6,27 +6,12 @@ use std::{ }; use tokio::sync::oneshot::{self, error::TryRecvError}; -pub trait Reason: Debug + Send {} -impl Reason for T where T: Debug + Send {} - -#[derive(Debug, thiserror::Error)] -#[error("Broker rejected the request, reason: {0:?}")] -pub struct Rejection(Box); - -impl Rejection { - fn new(reason: R) -> Self { - Self(Box::new(reason)) - } -} - #[derive(Debug, thiserror::Error)] pub enum TokenError { #[error("Sender has nothing to send instantly")] Waiting, #[error("Sender side of channel was dropped")] Disconnected, - #[error("Broker rejected the request, reason: {0:?}")] - Rejection(#[from] Rejection), } pub type NoResponse = (); @@ -35,7 +20,7 @@ pub type NoResponse = (); /// 1. Packet is acknowldged by the broker, e.g. QoS 1/2 Publish, Subscribe and Unsubscribe /// 2. QoS 0 packet finishes processing in the [`EventLoop`] pub struct Token { - rx: oneshot::Receiver>, + rx: oneshot::Receiver, } impl Future for Token { @@ -45,8 +30,7 @@ impl Future for Token { let polled = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }.poll(cx); match polled { - Poll::Ready(Ok(Ok(p))) => Poll::Ready(Ok(p)), - Poll::Ready(Ok(Err(e))) => Poll::Ready(Err(TokenError::Rejection(e))), + Poll::Ready(Ok(p)) => Poll::Ready(Ok(p)), Poll::Ready(Err(_)) => Poll::Ready(Err(TokenError::Disconnected)), Poll::Pending => Poll::Pending, } @@ -66,8 +50,7 @@ impl Token { pub fn wait(self) -> Result { self.rx .blocking_recv() - .map_err(|_| TokenError::Disconnected)? - .map_err(TokenError::Rejection) + .map_err(|_| TokenError::Disconnected) } /// Attempts to check if the packet handling has been completed, without blocking the current thread. @@ -77,17 +60,16 @@ impl Token { /// Multiple calls to this functions can fail with [`TokenError::Disconnected`] /// if the promise has already been resolved. pub fn check(&mut self) -> Result { - match self.rx.try_recv() { - Ok(r) => r.map_err(TokenError::Rejection), - Err(TryRecvError::Empty) => Err(TokenError::Waiting), - Err(TryRecvError::Closed) => Err(TokenError::Disconnected), - } + self.rx.try_recv().map_err(|e| match e { + TryRecvError::Empty => TokenError::Waiting, + TryRecvError::Closed => TokenError::Disconnected, + }) } } #[derive(Debug)] pub struct Resolver { - tx: oneshot::Sender>, + tx: oneshot::Sender, } impl Resolver { @@ -105,13 +87,7 @@ impl Resolver { } pub fn resolve(self, resolved: T) { - if self.tx.send(Ok(resolved)).is_err() { - trace!("Promise was dropped") - } - } - - pub fn reject(self, reasons: R) { - if self.tx.send(Err(Rejection::new(reasons))).is_err() { + if self.tx.send(resolved).is_err() { trace!("Promise was dropped") } } diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 6f537f4d7..d08de1364 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -3,13 +3,13 @@ use std::time::Duration; use super::mqttbytes::v5::{ - Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, - Unsubscribe, UnsubscribeProperties, + Filter, PubAck, PubRec, Publish, PublishProperties, SubAck, Subscribe, SubscribeProperties, + UnsubAck, Unsubscribe, UnsubscribeProperties, }; use super::mqttbytes::QoS; -use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; +use super::{AckOfAck, AckOfPub, ConnectionError, Event, EventLoop, MqttOptions, Request}; use crate::tokens::{NoResponse, Resolver, Token}; -use crate::{valid_filter, valid_topic, Pkid}; +use crate::{valid_filter, valid_topic}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -79,7 +79,7 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -104,7 +104,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -119,7 +119,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -135,7 +135,7 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -160,7 +160,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -174,7 +174,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -183,7 +183,7 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { + pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let ack = get_ack_req(publish, resolver); @@ -195,7 +195,7 @@ impl AsyncClient { } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { @@ -213,7 +213,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: Option, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, { @@ -234,7 +234,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: PublishProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, { @@ -248,7 +248,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, { @@ -262,7 +262,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); @@ -281,7 +281,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, Some(properties)).await } @@ -289,7 +289,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, None).await } @@ -299,7 +299,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); @@ -318,7 +318,7 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_try_subscribe(topic, qos, Some(properties)) } @@ -326,7 +326,7 @@ impl AsyncClient { &self, topic: S, qos: QoS, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_try_subscribe(topic, qos, None) } @@ -335,7 +335,7 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result, ClientError> + ) -> Result, ClientError> where T: IntoIterator, { @@ -355,14 +355,14 @@ impl AsyncClient { &self, topics: T, properties: SubscribeProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } - pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -374,7 +374,7 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result, ClientError> + ) -> Result, ClientError> where T: IntoIterator, { @@ -394,14 +394,14 @@ impl AsyncClient { &self, topics: T, properties: SubscribeProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } - pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -413,7 +413,7 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe, resolver); @@ -426,11 +426,14 @@ impl AsyncClient { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_unsubscribe(topic, Some(properties)).await } - pub async fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + pub async fn unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { self.handle_unsubscribe(topic, None).await } @@ -439,7 +442,7 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe, resolver); @@ -452,11 +455,14 @@ impl AsyncClient { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_try_unsubscribe(topic, Some(properties)) } - pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { + pub fn try_unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { self.handle_try_unsubscribe(topic, None) } @@ -479,10 +485,10 @@ impl AsyncClient { } } -fn get_ack_req(publish: &Publish, resolver: Resolver<()>) -> Option { +fn get_ack_req(publish: &Publish, resolver: Resolver) -> Option { let ack = match publish.qos { QoS::AtMostOnce => { - resolver.resolve(()); + resolver.resolve(AckOfAck::None); return None; } QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid, None), resolver), @@ -541,7 +547,7 @@ impl Client { retain: bool, payload: P, properties: Option, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -566,7 +572,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -580,7 +586,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -595,7 +601,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -610,7 +616,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result, ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -619,7 +625,7 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result, ClientError> { + pub fn ack(&self, publish: &Publish) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let ack = get_ack_req(publish, resolver); @@ -631,7 +637,7 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { self.client.try_ack(publish) } @@ -641,7 +647,7 @@ impl Client { topic: S, qos: QoS, properties: Option, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); @@ -660,7 +666,7 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, Some(properties)) } @@ -668,7 +674,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, None) } @@ -678,7 +684,7 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.client .try_subscribe_with_properties(topic, qos, properties) } @@ -687,7 +693,7 @@ impl Client { &self, topic: S, qos: QoS, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.client.try_subscribe(topic, qos) } @@ -696,7 +702,7 @@ impl Client { &self, topics: T, properties: Option, - ) -> Result, ClientError> + ) -> Result, ClientError> where T: IntoIterator, { @@ -716,14 +722,14 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } - pub fn subscribe_many(&self, topics: T) -> Result, ClientError> + pub fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -734,7 +740,7 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result, ClientError> + ) -> Result, ClientError> where T: IntoIterator, { @@ -742,7 +748,7 @@ impl Client { .try_subscribe_many_with_properties(topics, properties) } - pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -754,7 +760,7 @@ impl Client { &self, topic: S, properties: Option, - ) -> Result, ClientError> { + ) -> Result, ClientError> { let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); let request = Request::Unsubscribe(unsubscribe, resolver); @@ -767,11 +773,11 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.handle_unsubscribe(topic, Some(properties)) } - pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { self.handle_unsubscribe(topic, None) } @@ -780,12 +786,15 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result, ClientError> { + ) -> Result, ClientError> { self.client .try_unsubscribe_with_properties(topic, properties) } - pub fn try_unsubscribe>(&self, topic: S) -> Result, ClientError> { + pub fn try_unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { self.client.try_unsubscribe(topic) } diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 22b1942c2..d1b044b31 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -15,8 +15,7 @@ pub mod mqttbytes; mod state; use crate::tokens::Resolver; -use crate::{NetworkOptions, Transport}; -use crate::{Outgoing, Pkid}; +use crate::{NetworkOptions, Outgoing, Transport}; use mqttbytes::v5::*; @@ -32,16 +31,31 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; +/// Used to encapsulate all publish acknowledgents in v5 +#[derive(Debug)] +pub enum AckOfPub { + PubAck(PubAck), + PubComp(PubComp), + None, +} + +/// Used to encapsulate all ack/pubrel acknowledgements in v5 +#[derive(Debug)] +pub enum AckOfAck { + None, + PubRel(PubRel), +} + /// Requests by the client to mqtt event loop. Request are /// handled one by one. #[derive(Debug)] pub enum Request { - Publish(Publish, Resolver), - PubAck(PubAck, Resolver<()>), - PubRec(PubRec, Resolver<()>), - PubRel(PubRel, Resolver), - Subscribe(Subscribe, Resolver), - Unsubscribe(Unsubscribe, Resolver), + Publish(Publish, Resolver), + PubAck(PubAck, Resolver), + PubRec(PubRec, Resolver), + PubRel(PubRel, Resolver), + Subscribe(Subscribe, Resolver), + Unsubscribe(Unsubscribe, Resolver), Disconnect(Resolver<()>), PingReq, } diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 6f9e430c3..d291fbd5a 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -10,7 +10,7 @@ use super::{ }, Error as MqttError, QoS, }, - Event, Incoming, Outgoing, Request, + AckOfAck, AckOfPub, Event, Incoming, Outgoing, Request, }; use bytes::Bytes; @@ -103,7 +103,7 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option<(Publish, Resolver)>, + pub collision: Option<(Publish, Resolver)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -116,12 +116,14 @@ pub struct MqttState { pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, - /// Waiters for publish acknowledgements - pub_ack_waiter: HashMap>, + /// Waiters for publish acknowledgements, qos 1/2 + pub_ack_waiter: HashMap>, + /// Waiters for PubRel, qos 2 + pub_rel_waiter: HashMap>, /// Waiters for subscribe acknowledgements - sub_ack_waiter: HashMap>, + sub_ack_waiter: HashMap>, /// Waiters for unsubscribe acknowledgements - unsub_ack_waiter: HashMap>, + unsub_ack_waiter: HashMap>, } impl MqttState { @@ -150,6 +152,7 @@ impl MqttState { max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, pub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + pub_rel_waiter: HashMap::with_capacity(max_inflight as usize), sub_ack_waiter: HashMap::with_capacity(max_inflight as usize), unsub_ack_waiter: HashMap::with_capacity(max_inflight as usize), } @@ -209,13 +212,10 @@ impl MqttState { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? } Request::PubAck(puback, resolver) => { - resolver.resolve(()); + resolver.resolve(super::AckOfAck::None); self.outgoing_puback(puback)? } - Request::PubRec(pubrec, resolver) => { - resolver.resolve(()); - self.outgoing_pubrec(pubrec)? - } + Request::PubRec(pubrec, resolver) => self.outgoing_pubrec(pubrec, resolver)?, }; self.last_outgoing = Instant::now(); @@ -275,15 +275,7 @@ impl MqttState { } } - if suback - .return_codes - .iter() - .all(|r| matches!(r, SubscribeReasonCode::Success(_))) - { - resolver.resolve(suback.pkid); - } else { - resolver.reject(suback.return_codes); - } + resolver.resolve(suback); Ok(None) } @@ -302,15 +294,7 @@ impl MqttState { } } - if unsuback - .reasons - .iter() - .all(|r| matches!(r, UnsubAckReason::Success)) - { - resolver.resolve(unsuback.pkid); - } else { - resolver.reject(unsuback.reasons); - } + resolver.resolve(unsuback); Ok(None) } @@ -393,7 +377,8 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid, None); - return self.outgoing_pubrec(pubrec); + let (resolver, _) = Resolver::new(); + return self.outgoing_pubrec(pubrec, resolver); } Ok(None) } @@ -412,11 +397,7 @@ impl MqttState { .take(); // Resolve promise for QoS 1 - if puback.reason == PubAckReason::Success { - resolver.resolve(puback.pkid); - } else { - resolver.reject(puback.reason); - } + resolver.resolve(AckOfPub::PubAck(puback.clone())); self.inflight -= 1; @@ -482,6 +463,9 @@ impl MqttState { } self.incoming_pub.set(pubrel.pkid as usize, false); + let resolver = self.pub_rel_waiter.remove(&pubrel.pkid).unwrap(); + resolver.resolve(AckOfAck::PubRel(pubrel.clone())); + if pubrel.reason != PubRelReason::Success { warn!( "PubRel Pkid = {:?}, reason: {:?}", @@ -503,11 +487,7 @@ impl MqttState { }; // Resolve promise for QoS 2 - if pubcomp.reason == PubCompReason::Success { - resolver.resolve(pubcomp.pkid); - } else { - resolver.reject(pubcomp.reason); - } + resolver.resolve(AckOfPub::PubComp(pubcomp.clone())); self.outgoing_rel.set(pubcomp.pkid as usize, false); let outgoing = self @@ -543,7 +523,7 @@ impl MqttState { fn outgoing_publish( &mut self, mut publish: Publish, - resolver: Resolver, + resolver: Resolver, ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { @@ -595,7 +575,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); if publish.qos == QoS::AtMostOnce { - resolver.resolve(0); + resolver.resolve(AckOfPub::None) } else { self.pub_ack_waiter.insert(publish.pkid, resolver); } @@ -606,7 +586,7 @@ impl MqttState { fn outgoing_pubrel( &mut self, pubrel: PubRel, - resolver: Resolver, + resolver: Resolver, ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; @@ -627,10 +607,15 @@ impl MqttState { Ok(Some(Packet::PubAck(puback))) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { + fn outgoing_pubrec( + &mut self, + pubrec: PubRec, + resolver: Resolver, + ) -> Result, StateError> { let pkid = pubrec.pkid; let event = Event::Outgoing(Outgoing::PubRec(pkid)); self.events.push_back(event); + self.pub_rel_waiter.insert(pubrec.pkid, resolver); Ok(Some(Packet::PubRec(pubrec))) } @@ -670,7 +655,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - resolver: Resolver, + resolver: Resolver, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -695,7 +680,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - resolver: Resolver, + resolver: Resolver, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -724,7 +709,7 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 2915d3e8e..098f697ef 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -622,7 +622,7 @@ async fn resolve_on_qos0_before_write_to_tcp_buffer() { .await .unwrap() .unwrap(), - 0 + AckOfPub::None ); // Verify the packet still reached broker @@ -704,7 +704,7 @@ async fn resolve_on_qos1_ack_from_broker() { .await .unwrap() .unwrap(), - 1 + AckOfPub::PubAck(PubAck { pkid: 1 }) ); } @@ -777,7 +777,7 @@ async fn resolve_on_qos2_ack_from_broker() { .await .unwrap() .unwrap(), - 1 + AckOfPub::PubComp(PubComp { pkid: 1 }) ); } @@ -839,7 +839,8 @@ async fn resolve_on_sub_ack_from_broker() { timeout(Duration::from_secs(1), &mut token) .await .unwrap() - .unwrap(), + .unwrap() + .pkid, 1 ); } @@ -894,6 +895,6 @@ async fn resolve_on_unsub_ack_from_broker() { .await .unwrap() .unwrap(), - 1 + UnsubAck { pkid: 1 } ); } From c7ec9b4a033e3d0e6dc3ad7d27e0c05731ffdbc6 Mon Sep 17 00:00:00 2001 From: swanandx <73115739+swanandx@users.noreply.github.com> Date: Fri, 28 Feb 2025 11:33:17 +0530 Subject: [PATCH 30/30] fix: clear waiting subacks and unsubacks state --- rumqttc/src/state.rs | 6 ++++++ rumqttc/src/v5/state.rs | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index b1f918d09..f52bcdf32 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -134,6 +134,12 @@ impl MqttState { let request = Request::PubRel(PubRel::new(pkid), resolver); pending.push(request); } + + // we don't retransmit subscribe and unsubscribe packet + // so we can clear their state + self.sub_ack_waiter.clear(); + self.unsub_ack_waiter.clear(); + self.outgoing_rel.clear(); // remove packet ids of incoming qos2 publishes diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index d291fbd5a..b49a2c2cd 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -176,6 +176,12 @@ impl MqttState { let request = Request::PubRel(PubRel::new(pkid as u16, None), resolver); pending.push(request); } + + // we don't retransmit subscribe and unsubscribe packet + // so we can clear their state + self.sub_ack_waiter.clear(); + self.unsub_ack_waiter.clear(); + self.outgoing_rel.clear(); // remove packed ids of incoming qos2 publishes