Skip to content

Commit 001ae5f

Browse files
authored
Add middleware support (ntex-rs#177)
1 parent f2614ef commit 001ae5f

23 files changed

+381
-238
lines changed

.github/workflows/cov.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
uses: Swatinem/rust-cache@v1.0.1
2626

2727
- name: Generate code coverage
28-
run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
28+
run: cargo llvm-cov --features=ntex/compio --workspace --lcov --output-path lcov.info
2929

3030
- name: Upload coverage to Codecov
3131
uses: codecov/codecov-action@v4

.github/workflows/linux.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,20 @@ jobs:
4848
path: target
4949
key: ${{ matrix.version }}-x86_64-unknown-linux-gnu-cargo-build-trimmed-${{ hashFiles('**/Cargo.lock') }}
5050

51-
- name: Run tests
51+
- name: Run tests [tokio]
5252
uses: actions-rs/cargo@v1
5353
timeout-minutes: 40
5454
with:
5555
command: test
5656
args: --all --features=ntex/tokio -- --nocapture
5757

58+
# - name: Run tests [compio]
59+
# uses: actions-rs/cargo@v1
60+
# timeout-minutes: 40
61+
# with:
62+
# command: test
63+
# args: --all --features=ntex/compio -- --nocapture
64+
5865
- name: Install cargo-cache
5966
continue-on-error: true
6067
run: |

.github/workflows/windows.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,4 @@ jobs:
6969
uses: actions-rs/cargo@v1
7070
with:
7171
command: test
72-
args: --all --features=ntex/tokio -- --nocapture
72+
args: --all --features=ntex/compio -- --nocapture

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changes
22

3+
## [4.0.0] - 2024-10-05
4+
5+
* Middlewares support for mqtt server
6+
37
## [3.1.0] - 2024-08-23
48

59
* Derive Hash for the QoS enum #175

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "ntex-mqtt"
3-
version = "3.1.0"
3+
version = "4.0.0"
44
authors = ["ntex contributors <team@ntex.rs>"]
55
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
66
documentation = "https://docs.rs/ntex-mqtt"
@@ -36,4 +36,4 @@ ntex-tls = "2"
3636
ntex-macros = "0.1"
3737
openssl = "0.10"
3838
test-case = "3.2"
39-
ntex = { version = "2", features = ["tokio", "openssl"] }
39+
ntex = { version = "2", features = ["openssl"] }

examples/basic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ async fn main() -> std::io::Result<()> {
5252
ntex::server::build()
5353
.bind("mqtt", "127.0.0.1:1883", |_| {
5454
MqttServer::new()
55-
.v3(v3::MqttServer::new(handshake_v3).publish(publish_v3))
56-
.v5(v5::MqttServer::new(handshake_v5).publish(publish_v5))
55+
.v3(v3::MqttServer::new(handshake_v3).publish(publish_v3).finish())
56+
.v5(v5::MqttServer::new(handshake_v5).publish(publish_v5).finish())
5757
})?
5858
.workers(1)
5959
.run()

examples/openssl.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ async fn main() -> std::io::Result<()> {
6666
.map_err(|_err| MqttError::Service(ServerError {}))
6767
.and_then(
6868
MqttServer::new()
69-
.v3(v3::MqttServer::new(handshake_v3).publish(publish_v3))
70-
.v5(v5::MqttServer::new(handshake_v5).publish(publish_v5)),
69+
.v3(v3::MqttServer::new(handshake_v3).publish(publish_v3).finish())
70+
.v5(v5::MqttServer::new(handshake_v5).publish(publish_v5).finish()),
7171
)
7272
})?
7373
.workers(1)

examples/session.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,20 @@ async fn main() -> std::io::Result<()> {
9595
ntex::server::build()
9696
.bind("mqtt", "127.0.0.1:1883", |_| {
9797
MqttServer::new()
98-
.v3(v3::MqttServer::new(handshake_v3).publish(fn_factory_with_config(
99-
|session: v3::Session<MySession>| {
98+
.v3(v3::MqttServer::new(handshake_v3)
99+
.publish(fn_factory_with_config(|session: v3::Session<MySession>| {
100100
Ready::Ok::<_, MyServerError>(fn_service(move |req| {
101101
publish_v3(session.clone(), req)
102102
}))
103-
},
104-
)))
105-
.v5(v5::MqttServer::new(handshake_v5).publish(fn_factory_with_config(
106-
|session: v5::Session<MySession>| {
103+
}))
104+
.finish())
105+
.v5(v5::MqttServer::new(handshake_v5)
106+
.publish(fn_factory_with_config(|session: v5::Session<MySession>| {
107107
Ready::Ok::<_, MyServerError>(fn_service(move |req| {
108108
publish_v5(session.clone(), req)
109109
}))
110-
},
111-
)))
110+
}))
111+
.finish())
112112
})?
113113
.workers(1)
114114
.run()

src/inflight.rs

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,96 @@
11
//! Service that limits number of in-flight async requests.
22
use std::{cell::Cell, future::poll_fn, rc::Rc, task::Context, task::Poll};
33

4-
use ntex_service::{Service, ServiceCtx};
4+
use ntex_service::{Middleware, Service, ServiceCtx};
55
use ntex_util::task::LocalWaker;
66

7-
pub(crate) trait SizedRequest {
7+
/// Trait for types that could be sized
8+
pub trait SizedRequest {
89
fn size(&self) -> u32;
910
}
1011

11-
pub(crate) struct InFlightService<S> {
12-
count: Counter,
13-
service: S,
12+
/// Service that can limit number of in-flight async requests.
13+
///
14+
/// Default is 16 in-flight messages and 64kb size
15+
pub struct InFlightService {
16+
max_receive: u16,
17+
max_receive_size: usize,
1418
}
1519

16-
impl<S> InFlightService<S> {
17-
pub(crate) fn new(max_cap: u16, max_size: usize, service: S) -> Self {
18-
Self { service, count: Counter::new(max_cap, max_size) }
20+
impl Default for InFlightService {
21+
fn default() -> Self {
22+
Self { max_receive: 16, max_receive_size: 65535 }
1923
}
2024
}
2125

22-
impl<T, R> Service<R> for InFlightService<T>
26+
impl InFlightService {
27+
/// Create new `InFlightService` middleware
28+
///
29+
/// By default max receive is 16 and max size is 64kb
30+
pub fn new(max_receive: u16, max_receive_size: usize) -> Self {
31+
Self { max_receive, max_receive_size }
32+
}
33+
34+
/// Number of inbound in-flight concurrent messages.
35+
///
36+
/// By default max receive number is set to 16 messages
37+
pub fn max_receive(mut self, val: u16) -> Self {
38+
self.max_receive = val;
39+
self
40+
}
41+
42+
/// Total size of inbound in-flight messages.
43+
///
44+
/// By default total inbound in-flight size is set to 64Kb
45+
pub fn max_receive_size(mut self, val: usize) -> Self {
46+
self.max_receive_size = val;
47+
self
48+
}
49+
}
50+
51+
impl<S> Middleware<S> for InFlightService {
52+
type Service = InFlightServiceImpl<S>;
53+
54+
#[inline]
55+
fn create(&self, service: S) -> Self::Service {
56+
InFlightServiceImpl {
57+
service,
58+
count: Counter::new(self.max_receive, self.max_receive_size),
59+
}
60+
}
61+
}
62+
63+
pub struct InFlightServiceImpl<S> {
64+
count: Counter,
65+
service: S,
66+
}
67+
68+
impl<S, R> Service<R> for InFlightServiceImpl<S>
2369
where
24-
T: Service<R>,
70+
S: Service<R>,
2571
R: SizedRequest + 'static,
2672
{
27-
type Response = T::Response;
28-
type Error = T::Error;
73+
type Response = S::Response;
74+
type Error = S::Error;
2975

3076
ntex_service::forward_shutdown!(service);
3177

3278
#[inline]
33-
async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
79+
async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), S::Error> {
3480
ctx.ready(&self.service).await?;
81+
82+
// check if we have capacity
3583
self.count.available().await;
3684
Ok(())
3785
}
3886

3987
#[inline]
40-
async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<T::Response, T::Error> {
88+
async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<S::Response, S::Error> {
4189
let size = if self.count.0.max_size > 0 { req.size() } else { 0 };
42-
let _task_guard = self.count.get(size);
43-
ctx.call(&self.service, req).await
90+
let task_guard = self.count.get(size);
91+
let result = ctx.call(&self.service, req).await;
92+
drop(task_guard);
93+
result
4494
}
4595
}
4696

@@ -154,7 +204,8 @@ mod tests {
154204
async fn test_inflight() {
155205
let wait_time = Duration::from_millis(50);
156206

157-
let srv = Pipeline::new(InFlightService::new(1, 0, SleepService(wait_time))).bind();
207+
let srv =
208+
Pipeline::new(InFlightService::new(1, 0).create(SleepService(wait_time))).bind();
158209
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
159210

160211
let srv2 = srv.clone();
@@ -173,7 +224,8 @@ mod tests {
173224
async fn test_inflight2() {
174225
let wait_time = Duration::from_millis(50);
175226

176-
let srv = Pipeline::new(InFlightService::new(0, 10, SleepService(wait_time))).bind();
227+
let srv =
228+
Pipeline::new(InFlightService::new(0, 10).create(SleepService(wait_time))).bind();
177229
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
178230

179231
let srv2 = srv.clone();
@@ -227,11 +279,11 @@ mod tests {
227279
async fn test_inflight3() {
228280
let wait_time = Duration::from_millis(50);
229281

230-
let srv = Pipeline::new(InFlightService::new(
231-
1,
232-
10,
233-
Srv2 { dur: wait_time, cnt: Cell::new(false), waker: LocalWaker::new() },
234-
))
282+
let srv = Pipeline::new(InFlightService::new(1, 10).create(Srv2 {
283+
dur: wait_time,
284+
cnt: Cell::new(false),
285+
waker: LocalWaker::new(),
286+
}))
235287
.bind();
236288
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));
237289

src/io.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,7 @@ mod tests {
789789
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
790790

791791
// write side must be closed, dispatcher waiting for read side to close
792+
sleep(Millis(50)).await;
792793
assert!(client.is_closed());
793794

794795
// close read side
@@ -837,6 +838,7 @@ mod tests {
837838
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
838839

839840
// write side must be closed, dispatcher waiting for read side to close
841+
sleep(Millis(50)).await;
840842
assert!(client.is_closed());
841843

842844
// close read side

0 commit comments

Comments
 (0)