Skip to content

Commit fffbe65

Browse files
authored
Merge pull request #430 from ReactiveBayes/splitstatscallbacks
Split show for benchmarkcallbacks
2 parents 8c2fe99 + 1f58f9d commit fffbe65

File tree

3 files changed

+89
-33
lines changed

3 files changed

+89
-33
lines changed

docs/src/manuals/debugging.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,13 @@ The `RxInferBenchmarkCallbacks` structure collects timestamps at various stages
264264

265265
```@docs
266266
RxInferBenchmarkCallbacks
267+
RxInfer.get_benchmark_stats
268+
RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY
267269
```
268270

271+
!!! note
272+
By default, the `RxInferBenchmarkCallbacks` structure uses a circular buffer with a limited capacity to store timestamps. This helps limit memory usage in long-running applications. You can change the buffer capacity by passing a different value to the `capacity` keyword argument of the `RxInferBenchmarkCallbacks` constructor.
273+
269274
This information can be used to:
270275
- Track performance statistics (min/max/average) of your inference procedure
271276
- Identify performance variability across runs

src/inference/benchmarkcallbacks.jl

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,38 @@
11
using PrettyTables
22
using PrettyTables.Printf
3+
using DataStructures: CircularBuffer
34

45
export RxInferBenchmarkCallbacks
56

67
"""
7-
RxInferBenchmarkCallbacks
8+
DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY
9+
10+
The default capacity of the circular buffers used to store timestamps in the `RxInferBenchmarkCallbacks` structure.
11+
"""
12+
const DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY = 1000
13+
14+
"""
15+
RxInferBenchmarkCallbacks(; capacity = RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)
816
917
A callback structure for collecting timing information during the inference procedure.
1018
This structure collects timestamps for various stages of the inference process and aggregates
1119
them across multiple runs, allowing you to track performance statistics (min/max/average/etc.)
1220
of your model's creation and inference procedure. The structure supports pretty printing by default,
1321
displaying timing statistics in a human-readable format.
1422
23+
The structure uses circular buffers with a default capacity of $(DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY) entries to store timestamps,
24+
which helps to limit memory usage in long-running applications. Use `RxInferBenchmarkCallbacks(; capacity = N)` to change the buffer capacity.
25+
See also [`RxInfer.get_benchmark_stats(callbacks)`](@ref).
26+
1527
# Fields
16-
- `before_model_creation_ts`: Vector of timestamps before model creation
17-
- `after_model_creation_ts`: Vector of timestamps after model creation
18-
- `before_inference_ts`: Vector of timestamps before inference starts
19-
- `after_inference_ts`: Vector of timestamps after inference ends
20-
- `before_iteration_ts`: Vector of vectors of timestamps before each iteration
21-
- `after_iteration_ts`: Vector of vectors of timestamps after each iteration
22-
- `before_autostart_ts`: Vector of timestamps before autostart
23-
- `after_autostart_ts`: Vector of timestamps after autostart
28+
- `before_model_creation_ts`: CircularBuffer of timestamps before model creation
29+
- `after_model_creation_ts`: CircularBuffer of timestamps after model creation
30+
- `before_inference_ts`: CircularBuffer of timestamps before inference starts
31+
- `after_inference_ts`: CircularBuffer of timestamps after inference ends
32+
- `before_iteration_ts`: CircularBuffer of vectors of timestamps before each iteration
33+
- `after_iteration_ts`: CircularBuffer of vectors of timestamps after each iteration
34+
- `before_autostart_ts`: CircularBuffer of timestamps before autostart
35+
- `after_autostart_ts`: CircularBuffer of timestamps after autostart
2436
2537
# Example
2638
```julia
@@ -41,19 +53,31 @@ callbacks
4153
```
4254
"""
4355
struct RxInferBenchmarkCallbacks
44-
before_model_creation_ts::Vector{UInt64}
45-
after_model_creation_ts::Vector{UInt64}
46-
before_inference_ts::Vector{UInt64}
47-
after_inference_ts::Vector{UInt64}
48-
before_iteration_ts::Vector{Vector{UInt64}}
49-
after_iteration_ts::Vector{Vector{UInt64}}
50-
before_autostart_ts::Vector{UInt64}
51-
after_autostart_ts::Vector{UInt64}
56+
before_model_creation_ts::CircularBuffer{UInt64}
57+
after_model_creation_ts::CircularBuffer{UInt64}
58+
before_inference_ts::CircularBuffer{UInt64}
59+
after_inference_ts::CircularBuffer{UInt64}
60+
before_iteration_ts::CircularBuffer{Vector{UInt64}}
61+
after_iteration_ts::CircularBuffer{Vector{UInt64}}
62+
before_autostart_ts::CircularBuffer{UInt64}
63+
after_autostart_ts::CircularBuffer{UInt64}
5264
end
5365

54-
RxInferBenchmarkCallbacks() = RxInferBenchmarkCallbacks(UInt64[], UInt64[], UInt64[], UInt64[], Vector{UInt64}[], Vector{UInt64}[], UInt64[], UInt64[])
66+
function RxInferBenchmarkCallbacks(; capacity = DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)
67+
RxInferBenchmarkCallbacks(
68+
CircularBuffer{UInt64}(capacity),
69+
CircularBuffer{UInt64}(capacity),
70+
CircularBuffer{UInt64}(capacity),
71+
CircularBuffer{UInt64}(capacity),
72+
CircularBuffer{Vector{UInt64}}(capacity),
73+
CircularBuffer{Vector{UInt64}}(capacity),
74+
CircularBuffer{UInt64}(capacity),
75+
CircularBuffer{UInt64}(capacity)
76+
)
77+
end
5578

5679
check_available_callbacks(warn, callbacks::RxInferBenchmarkCallbacks, ::Val{AvailableCallbacks}) where {AvailableCallbacks} = nothing
80+
inference_get_callback(callbacks::RxInferBenchmarkCallbacks, name::Symbol) = nothing
5781

5882
Base.isempty(callbacks::RxInferBenchmarkCallbacks) = isempty(callbacks.before_model_creation_ts)
5983

@@ -79,8 +103,6 @@ function inference_invoke_callback(callbacks::RxInferBenchmarkCallbacks, name::S
79103
end
80104
end
81105

82-
inference_get_callback(callbacks::RxInferBenchmarkCallbacks, name::Symbol) = nothing
83-
84106
function prettytime(t::Union{UInt64, Float64})
85107
if t < 1e3
86108
value, units = t, "ns"
@@ -96,24 +118,31 @@ end
96118

97119
prettytime(s) = s
98120

99-
function Base.show(io::IO, callbacks::RxInferBenchmarkCallbacks)
100-
if isempty(callbacks)
101-
return nothing
102-
end
103-
104-
header = (["Operation", "Min", "Max", "Mean", "Median", "Std"],)
105-
106-
print(io, "RxInfer inference benchmark statistics: $(length(callbacks.before_model_creation_ts)) evaluations \n")
107-
108-
model_creation_time = callbacks.after_model_creation_ts .- callbacks.before_model_creation_ts
121+
"""
122+
get_benchmark_stats(callbacks::RxInferBenchmarkCallbacks)
123+
124+
Returns a matrix containing benchmark statistics for different operations in the inference process.
125+
The matrix contains the following columns:
126+
1. Operation name (String)
127+
2. Minimum time (Float64)
128+
3. Maximum time (Float64)
129+
4. Mean time (Float64)
130+
5. Median time (Float64)
131+
6. Standard deviation (Float64)
132+
133+
Each row represents a different operation (model creation, inference, iteration, autostart).
134+
Times are in nanoseconds.
135+
"""
136+
function get_benchmark_stats(callbacks::RxInferBenchmarkCallbacks)
137+
model_creation_time = collect(callbacks.after_model_creation_ts) .- collect(callbacks.before_model_creation_ts)
109138
stats_to_show = [("Model creation", model_creation_time)]
110-
inference_time = callbacks.after_inference_ts .- callbacks.before_inference_ts
111-
iteration_time = [callbacks.after_iteration_ts[i] .- callbacks.before_iteration_ts[i] for i in 1:length(callbacks.before_iteration_ts)]
139+
inference_time = collect(callbacks.after_inference_ts) .- collect(callbacks.before_inference_ts)
140+
iteration_time = [collect(callbacks.after_iteration_ts[i]) .- collect(callbacks.before_iteration_ts[i]) for i in 1:length(callbacks.before_iteration_ts)]
112141
if length(inference_time) > 0
113142
push!(stats_to_show, ("Inference", inference_time))
114143
push!(stats_to_show, ("Iteration", reshape(stack(iteration_time), :)))
115144
end
116-
autostart_time = callbacks.after_autostart_ts .- callbacks.before_autostart_ts
145+
autostart_time = collect(callbacks.after_autostart_ts) .- collect(callbacks.before_autostart_ts)
117146
if length(autostart_time) > 0
118147
push!(stats_to_show, ("Autostart", autostart_time))
119148
end
@@ -128,6 +157,19 @@ function Base.show(io::IO, callbacks::RxInferBenchmarkCallbacks)
128157
data[i, 5] = convert(Float64, median(time))
129158
data[i, 6] = convert(Float64, std(time))
130159
end
160+
return data
161+
end
162+
163+
function Base.show(io::IO, callbacks::RxInferBenchmarkCallbacks)
164+
if isempty(callbacks)
165+
return nothing
166+
end
167+
168+
header = (["Operation", "Min", "Max", "Mean", "Median", "Std"],)
169+
170+
print(io, "RxInfer inference benchmark statistics: $(length(callbacks.before_model_creation_ts)) evaluations \n")
171+
172+
data = get_benchmark_stats(callbacks)
131173
hl_v = Highlighter((data, i, j) -> (j == 3) && (data[i, j] > 10 * data[i, j - 1]), crayon"red bold")
132174
pretty_table(io, data; formatters = (s, i, j) -> prettytime(s), header = header, header_crayon = crayon"yellow bold", tf = tf_unicode_rounded, highlighters = hl_v)
133175
end

test/inference/inference_tests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,15 @@ end
10761076
@test length(last(callbacks.after_iteration_ts)) == 10
10771077
end
10781078

1079+
stats = RxInfer.get_benchmark_stats(callbacks)
1080+
for line in eachrow(stats)
1081+
@test line[2] > 0.0
1082+
@test line[3] > line[2]
1083+
@test line[2] < line[4] < line[3]
1084+
@test line[2] < line[5] < line[3]
1085+
@test !isnan(line[6])
1086+
end
1087+
10791088
@model function kalman_filter(x_prev_mean, x_prev_var, τ_shape, τ_rate, y)
10801089
x_prev ~ Normal(mean = x_prev_mean, variance = x_prev_var)
10811090
τ ~ Gamma(shape = τ_shape, rate = τ_rate)

0 commit comments

Comments
 (0)