Skip to content

Commit 1ca932e

Browse files
committed
[not working] add cancellation via token
1 parent 1276ded commit 1ca932e

File tree

28 files changed

+346
-171
lines changed

28 files changed

+346
-171
lines changed

engine/Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ tokio = { version = "1", default-features = false, features = [
127127
"macros",
128128
"time",
129129
] }
130+
tokio-util = "0.7"
130131

131132
[workspace.package]
132133
version = "0.203.1"

engine/baml-runtime/src/internal/llm_client/primitive/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ mod anthropic;
3131
mod aws;
3232
mod google;
3333
mod openai;
34-
pub(super) mod request;
34+
pub mod request;
3535
mod stream_request;
3636
mod vertex;
3737

@@ -436,6 +436,7 @@ mod tests {
436436
allow_proxy: bool,
437437
stream: bool,
438438
expose_secrets: bool,
439+
cancellation_token: Option<tokio_util::sync::CancellationToken>,
439440
) -> Result<reqwest::RequestBuilder> {
440441
unimplemented!("Not used in tests")
441442
}

engine/baml-runtime/tests/test_cancellation.rs

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ mod cancellation_tests {
146146

147147
assert!(result.is_ok(), "Stream should complete (with cancellation)");
148148
let (stream_result, _) = result.unwrap();
149-
149+
150150
// Should be cancelled
151151
assert!(stream_result.is_err());
152152
let error_msg = stream_result.unwrap_err().to_string();
@@ -185,7 +185,8 @@ mod cancellation_tests {
185185
"##,
186186
);
187187

188-
let runtime = BamlRuntime::from_file_content(".", &files, HashMap::new()).unwrap();
188+
let runtime =
189+
BamlRuntime::from_file_content(".", &files, HashMap::<&str, &str>::new()).unwrap();
189190
let ctx = runtime.create_ctx_manager(BamlValue::String("test".to_string()), None);
190191

191192
// Create a stream
@@ -364,18 +365,21 @@ mod cancellation_tests {
364365

365366
// Should be cancelled
366367
assert!(result.is_err());
367-
368+
368369
// Callback should not have been called due to immediate cancellation
369370
let final_count = *callback_count.lock().unwrap();
370-
assert_eq!(final_count, 0, "Callbacks should not be invoked after cancellation");
371+
assert_eq!(
372+
final_count, 0,
373+
"Callbacks should not be invoked after cancellation"
374+
);
371375
}
372376

373377
/// Test that cancellation works with sync runtime
374378
#[cfg(not(target_arch = "wasm32"))]
375379
#[test]
376380
fn test_sync_stream_cancellation() {
377381
let rt = tokio::runtime::Runtime::new().unwrap();
378-
382+
379383
rt.block_on(async {
380384
let mut files = HashMap::new();
381385
files.insert(
@@ -452,37 +456,38 @@ mod unit_tests {
452456
#[tokio::test]
453457
async fn test_cancellation_token_basic() {
454458
let token = CancellationToken::new();
455-
459+
456460
// Initially not cancelled
457461
assert!(!token.is_cancelled());
458-
462+
459463
// Cancel the token
460464
token.cancel();
461-
465+
462466
// Now should be cancelled
463467
assert!(token.is_cancelled());
464-
468+
465469
// cancelled() future should complete immediately
466-
let result = tokio::time::timeout(
467-
std::time::Duration::from_millis(100),
468-
token.cancelled()
469-
).await;
470-
471-
assert!(result.is_ok(), "cancelled() future should complete immediately");
470+
let result =
471+
tokio::time::timeout(std::time::Duration::from_millis(100), token.cancelled()).await;
472+
473+
assert!(
474+
result.is_ok(),
475+
"cancelled() future should complete immediately"
476+
);
472477
}
473478

474479
/// Test CancellationToken with tokio::select!
475480
#[tokio::test]
476481
async fn test_cancellation_with_select() {
477482
let token = CancellationToken::new();
478483
let token_clone = token.clone();
479-
484+
480485
// Spawn a task that cancels after delay
481486
tokio::spawn(async move {
482487
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
483488
token_clone.cancel();
484489
});
485-
490+
486491
// Use select to race between work and cancellation
487492
let result = tokio::select! {
488493
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
@@ -492,7 +497,7 @@ mod unit_tests {
492497
"cancelled"
493498
}
494499
};
495-
500+
496501
assert_eq!(result, "cancelled");
497502
}
498503

@@ -501,14 +506,14 @@ mod unit_tests {
501506
async fn test_child_token_cancellation() {
502507
let parent_token = CancellationToken::new();
503508
let child_token = parent_token.child_token();
504-
509+
505510
// Initially neither is cancelled
506511
assert!(!parent_token.is_cancelled());
507512
assert!(!child_token.is_cancelled());
508-
513+
509514
// Cancel parent
510515
parent_token.cancel();
511-
516+
512517
// Both should be cancelled
513518
assert!(parent_token.is_cancelled());
514519
assert!(child_token.is_cancelled());

engine/baml-runtime/tests/test_integration_cancellation.rs

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ mod integration_cancellation_tests {
4444
"##,
4545
);
4646

47-
let runtime = BamlRuntime::from_file_content(".", &files, HashMap::new()).unwrap();
47+
let runtime =
48+
BamlRuntime::from_file_content(".", &files, HashMap::<&str, &str>::new()).unwrap();
4849
let ctx = runtime.create_ctx_manager(BamlValue::String("test".to_string()), None);
4950

5051
// Create a stream (simulating TypeScript BamlStream creation)
@@ -93,12 +94,12 @@ mod integration_cancellation_tests {
9394

9495
// Should complete within timeout due to cancellation
9596
assert!(result.is_ok(), "Stream should complete due to cancellation");
96-
97+
9798
let (stream_result, _) = result.unwrap();
98-
99+
99100
// Should be cancelled (not a successful completion)
100101
assert!(stream_result.is_err(), "Stream should be cancelled");
101-
102+
102103
// Should be much faster than the 10-second delay endpoint
103104
assert!(
104105
elapsed < Duration::from_secs(2),
@@ -143,7 +144,8 @@ mod integration_cancellation_tests {
143144
"##,
144145
);
145146

146-
let runtime = BamlRuntime::from_file_content(".", &files, HashMap::new()).unwrap();
147+
let runtime =
148+
BamlRuntime::from_file_content(".", &files, HashMap::<&str, &str>::new()).unwrap();
147149
let ctx = runtime.create_ctx_manager(BamlValue::String("test".to_string()), None);
148150

149151
// Create multiple streams
@@ -167,7 +169,7 @@ mod integration_cancellation_tests {
167169

168170
let token = CancellationToken::new();
169171
stream.set_cancellation_token(token.clone());
170-
172+
171173
streams.push(stream);
172174
tokens.push(token);
173175
}
@@ -190,7 +192,8 @@ mod integration_cancellation_tests {
190192
None,
191193
HashMap::new(),
192194
),
193-
).await;
195+
)
196+
.await;
194197
(i, result)
195198
});
196199
handles.push(handle);
@@ -201,13 +204,21 @@ mod integration_cancellation_tests {
201204

202205
// Check results
203206
for (handle_result, (stream_index, timeout_result)) in results.into_iter().enumerate() {
204-
assert!(handle_result.is_ok(), "Task {} should complete", stream_index);
205-
207+
assert!(
208+
handle_result.is_ok(),
209+
"Task {} should complete",
210+
stream_index
211+
);
212+
206213
let (stream_result, _) = timeout_result.unwrap().unwrap();
207-
214+
208215
if stream_index == 1 {
209216
// Middle stream should be cancelled
210-
assert!(stream_result.is_err(), "Stream {} should be cancelled", stream_index);
217+
assert!(
218+
stream_result.is_err(),
219+
"Stream {} should be cancelled",
220+
stream_index
221+
);
211222
let error_msg = stream_result.unwrap_err().to_string();
212223
assert!(
213224
error_msg.contains("cancelled") || error_msg.contains("canceled"),
@@ -249,7 +260,8 @@ mod integration_cancellation_tests {
249260
"##,
250261
);
251262

252-
let runtime = BamlRuntime::from_file_content(".", &files, HashMap::new()).unwrap();
263+
let runtime =
264+
BamlRuntime::from_file_content(".", &files, HashMap::<&str, &str>::new()).unwrap();
253265
let ctx = runtime.create_ctx_manager(BamlValue::String("test".to_string()), None);
254266

255267
let mut stream = runtime
@@ -299,7 +311,7 @@ mod integration_cancellation_tests {
299311

300312
// Should be cancelled
301313
assert!(result.is_err());
302-
314+
303315
// Events should be minimal due to quick cancellation
304316
let final_count = *event_count.lock().unwrap();
305317
assert!(
@@ -356,7 +368,7 @@ mod integration_cancellation_tests {
356368

357369
let token = CancellationToken::new();
358370
stream.set_cancellation_token(token.clone());
359-
371+
360372
streams.push(stream);
361373
tokens.push(token);
362374
}
@@ -368,19 +380,21 @@ mod integration_cancellation_tests {
368380

369381
// Run all streams - they should all be cancelled quickly
370382
let start_time = std::time::Instant::now();
371-
383+
372384
let mut handles = Vec::new();
373385
for mut stream in streams {
374386
let ctx_clone = ctx.clone();
375387
let handle = tokio::spawn(async move {
376-
stream.run(
377-
None::<fn()>,
378-
None::<fn(baml_runtime::FunctionResult)>,
379-
&ctx_clone,
380-
None,
381-
None,
382-
HashMap::new(),
383-
).await
388+
stream
389+
.run(
390+
None::<fn()>,
391+
None::<fn(baml_runtime::FunctionResult)>,
392+
&ctx_clone,
393+
None,
394+
None,
395+
HashMap::new(),
396+
)
397+
.await
384398
});
385399
handles.push(handle);
386400
}

engine/baml-schema-wasm/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ serde_json.workspace = true
4646
serde-wasm-bindgen = "0.6.5"
4747
time.workspace = true
4848
tokio = { workspace = true, features = ["sync"] }
49+
tokio-util = { workspace = true }
50+
thiserror = { workspace = true }
4951
url.workspace = true
5052
uuid = { version = "1.8", features = ["v4", "js"] }
5153
wasm-bindgen = "=0.2.92"
@@ -55,6 +57,8 @@ web-time.workspace = true
5557
either = "1.8.1"
5658
itertools = "0.13.0"
5759
once_cell.workspace = true
60+
61+
5862
[dependencies.web-sys]
5963
version = "0.3.69"
6064
features = [

0 commit comments

Comments
 (0)