Commit 1f58f9d

Merge pull request #431 from ReactiveBayes/uselimitedbuffer-for-stats-callbacks
2 parents 2c31561 + 7d2255e commit 1f58f9d

2 files changed, +50 -24 lines changed


docs/src/manuals/debugging.md

Lines changed: 4 additions & 0 deletions
@@ -265,8 +265,12 @@ The `RxInferBenchmarkCallbacks` structure collects timestamps at various stages
 ```@docs
 RxInferBenchmarkCallbacks
 RxInfer.get_benchmark_stats
+RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY
 ```
 
+!!! note
+    By default, the `RxInferBenchmarkCallbacks` structure uses a circular buffer with a limited capacity to store timestamps. This helps limit memory usage in long-running applications. You can change the buffer capacity by passing a different value to the `capacity` keyword argument of the `RxInferBenchmarkCallbacks` constructor.
+
 This information can be used to:
 - Track performance statistics (min/max/average) of your inference procedure
 - Identify performance variability across runs
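
Review note: the bounded-memory behaviour described in the docs note above can be sketched with `CircularBuffer` from DataStructures.jl on its own. This is only an illustration, not code from the commit. The capacity of 3 and the pushed values are made up so the wrap-around is easy to see, and the snippet does not load RxInfer.

```julia
using DataStructures: CircularBuffer, capacity

# A deliberately tiny buffer (assumption for illustration;
# the commit sets the default capacity to 1000).
buf = CircularBuffer{UInt64}(3)

# Push five values. Once the buffer is full, each push
# overwrites the oldest entry instead of growing the storage.
for ts in UInt64(1):UInt64(5)
    push!(buf, ts)
end

# `collect` copies the retained entries into a plain Vector, oldest first.
kept = collect(buf)

println("capacity = ", capacity(buf), ", length = ", length(buf))
println("kept = ", kept)
```

With the constructor added in this commit, the equivalent knob would presumably be `RxInferBenchmarkCallbacks(; capacity = 3)`, in which case the reported statistics would cover only the most recently retained timestamps.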

src/inference/benchmarkcallbacks.jl

Lines changed: 46 additions & 24 deletions
@@ -1,26 +1,38 @@
 using PrettyTables
 using PrettyTables.Printf
+using DataStructures: CircularBuffer
 
 export RxInferBenchmarkCallbacks
 
 """
-    RxInferBenchmarkCallbacks
+    DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY
+
+The default capacity of the circular buffers used to store timestamps in the `RxInferBenchmarkCallbacks` structure.
+"""
+const DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY = 1000
+
+"""
+    RxInferBenchmarkCallbacks(; capacity = RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)
 
 A callback structure for collecting timing information during the inference procedure.
 This structure collects timestamps for various stages of the inference process and aggregates
 them across multiple runs, allowing you to track performance statistics (min/max/average/etc.)
 of your model's creation and inference procedure. The structure supports pretty printing by default,
 displaying timing statistics in a human-readable format.
 
+The structure uses circular buffers with a default capacity of $(DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY) entries to store timestamps,
+which helps to limit memory usage in long-running applications. Use `RxInferBenchmarkCallbacks(; capacity = N)` to change the buffer capacity.
+See also [`RxInfer.get_benchmark_stats(callbacks)`](@ref).
+
 # Fields
-- `before_model_creation_ts`: Vector of timestamps before model creation
-- `after_model_creation_ts`: Vector of timestamps after model creation
-- `before_inference_ts`: Vector of timestamps before inference starts
-- `after_inference_ts`: Vector of timestamps after inference ends
-- `before_iteration_ts`: Vector of vectors of timestamps before each iteration
-- `after_iteration_ts`: Vector of vectors of timestamps after each iteration
-- `before_autostart_ts`: Vector of timestamps before autostart
-- `after_autostart_ts`: Vector of timestamps after autostart
+- `before_model_creation_ts`: CircularBuffer of timestamps before model creation
+- `after_model_creation_ts`: CircularBuffer of timestamps after model creation
+- `before_inference_ts`: CircularBuffer of timestamps before inference starts
+- `after_inference_ts`: CircularBuffer of timestamps after inference ends
+- `before_iteration_ts`: CircularBuffer of vectors of timestamps before each iteration
+- `after_iteration_ts`: CircularBuffer of vectors of timestamps after each iteration
+- `before_autostart_ts`: CircularBuffer of timestamps before autostart
+- `after_autostart_ts`: CircularBuffer of timestamps after autostart
 
 # Example
 ```julia
@@ -41,19 +53,31 @@ callbacks
 ```
 """
 struct RxInferBenchmarkCallbacks
-    before_model_creation_ts::Vector{UInt64}
-    after_model_creation_ts::Vector{UInt64}
-    before_inference_ts::Vector{UInt64}
-    after_inference_ts::Vector{UInt64}
-    before_iteration_ts::Vector{Vector{UInt64}}
-    after_iteration_ts::Vector{Vector{UInt64}}
-    before_autostart_ts::Vector{UInt64}
-    after_autostart_ts::Vector{UInt64}
+    before_model_creation_ts::CircularBuffer{UInt64}
+    after_model_creation_ts::CircularBuffer{UInt64}
+    before_inference_ts::CircularBuffer{UInt64}
+    after_inference_ts::CircularBuffer{UInt64}
+    before_iteration_ts::CircularBuffer{Vector{UInt64}}
+    after_iteration_ts::CircularBuffer{Vector{UInt64}}
+    before_autostart_ts::CircularBuffer{UInt64}
+    after_autostart_ts::CircularBuffer{UInt64}
 end
 
-RxInferBenchmarkCallbacks() = RxInferBenchmarkCallbacks(UInt64[], UInt64[], UInt64[], UInt64[], Vector{UInt64}[], Vector{UInt64}[], UInt64[], UInt64[])
+function RxInferBenchmarkCallbacks(; capacity = DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)
+    RxInferBenchmarkCallbacks(
+        CircularBuffer{UInt64}(capacity),
+        CircularBuffer{UInt64}(capacity),
+        CircularBuffer{UInt64}(capacity),
+        CircularBuffer{UInt64}(capacity),
+        CircularBuffer{Vector{UInt64}}(capacity),
+        CircularBuffer{Vector{UInt64}}(capacity),
+        CircularBuffer{UInt64}(capacity),
+        CircularBuffer{UInt64}(capacity)
+    )
+end
 
 check_available_callbacks(warn, callbacks::RxInferBenchmarkCallbacks, ::Val{AvailableCallbacks}) where {AvailableCallbacks} = nothing
+inference_get_callback(callbacks::RxInferBenchmarkCallbacks, name::Symbol) = nothing
 
 Base.isempty(callbacks::RxInferBenchmarkCallbacks) = isempty(callbacks.before_model_creation_ts)
 
@@ -79,8 +103,6 @@ function inference_invoke_callback(callbacks::RxInferBenchmarkCallbacks, name::S
     end
 end
 
-inference_get_callback(callbacks::RxInferBenchmarkCallbacks, name::Symbol) = nothing
-
 function prettytime(t::Union{UInt64, Float64})
     if t < 1e3
         value, units = t, "ns"
@@ -112,15 +134,15 @@ Each row represents a different operation (model creation, inference, iteration,
 Times are in nanoseconds.
 """
 function get_benchmark_stats(callbacks::RxInferBenchmarkCallbacks)
-    model_creation_time = callbacks.after_model_creation_ts .- callbacks.before_model_creation_ts
+    model_creation_time = collect(callbacks.after_model_creation_ts) .- collect(callbacks.before_model_creation_ts)
     stats_to_show = [("Model creation", model_creation_time)]
-    inference_time = callbacks.after_inference_ts .- callbacks.before_inference_ts
-    iteration_time = [callbacks.after_iteration_ts[i] .- callbacks.before_iteration_ts[i] for i in 1:length(callbacks.before_iteration_ts)]
+    inference_time = collect(callbacks.after_inference_ts) .- collect(callbacks.before_inference_ts)
+    iteration_time = [collect(callbacks.after_iteration_ts[i]) .- collect(callbacks.before_iteration_ts[i]) for i in 1:length(callbacks.before_iteration_ts)]
     if length(inference_time) > 0
         push!(stats_to_show, ("Inference", inference_time))
         push!(stats_to_show, ("Iteration", reshape(stack(iteration_time), :)))
     end
-    autostart_time = callbacks.after_autostart_ts .- callbacks.before_autostart_ts
+    autostart_time = collect(callbacks.after_autostart_ts) .- collect(callbacks.before_autostart_ts)
     if length(autostart_time) > 0
         push!(stats_to_show, ("Autostart", autostart_time))
     end
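
Review note: the duration computation in `get_benchmark_stats` reduces to an elementwise difference of paired "before" and "after" timestamp sequences. A minimal standalone sketch of that idea follows. It uses plain vectors as stand-ins for the buffers, and the `record!` helper and the timed workload are hypothetical, not part of RxInfer.

```julia
using Statistics: mean

# Hypothetical helper: run `f` once and record a timestamp pair around it.
# `time_ns()` returns a UInt64 number of nanoseconds.
function record!(f, before::Vector{UInt64}, after::Vector{UInt64})
    push!(before, time_ns())
    f()
    push!(after, time_ns())
    return nothing
end

before_ts = UInt64[]
after_ts = UInt64[]

# Time an arbitrary workload five times.
for _ in 1:5
    record!(before_ts, after_ts) do
        sum(rand(1_000))
    end
end

# Entry i of each vector belongs to the same run, so the
# broadcasted difference yields one duration per run, in nanoseconds.
durations = after_ts .- before_ts

stats = (min = minimum(durations), max = maximum(durations), mean = mean(durations))
println(stats)
```

The pairing only holds while both sequences receive the same number of pushes, which is also why buffers created with the same capacity stay aligned after they wrap around.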
