1
1
using PrettyTables
2
2
using PrettyTables. Printf
3
+ using DataStructures: CircularBuffer
3
4
4
5
export RxInferBenchmarkCallbacks
5
6
6
7
"""
7
- RxInferBenchmarkCallbacks
8
+ DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY
9
+
10
+ The default capacity of the circular buffers used to store timestamps in the `RxInferBenchmarkCallbacks` structure.
11
+ """
12
+ const DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY = 1000
13
+
14
+ """
15
+ RxInferBenchmarkCallbacks(; capacity = RxInfer.DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)
8
16
9
17
A callback structure for collecting timing information during the inference procedure.
10
18
This structure collects timestamps for various stages of the inference process and aggregates
11
19
them across multiple runs, allowing you to track performance statistics (min/max/average/etc.)
12
20
of your model's creation and inference procedure. The structure supports pretty printing by default,
13
21
displaying timing statistics in a human-readable format.
14
22
23
+ The structure uses circular buffers with a default capacity of $(DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY) entries to store timestamps,
24
+ which helps to limit memory usage in long-running applications. Use `RxInferBenchmarkCallbacks(; capacity = N)` to change the buffer capacity.
25
+ See also [`RxInfer.get_benchmark_stats(callbacks)`](@ref).
26
+
15
27
# Fields
16
- - `before_model_creation_ts`: Vector of timestamps before model creation
17
- - `after_model_creation_ts`: Vector of timestamps after model creation
18
- - `before_inference_ts`: Vector of timestamps before inference starts
19
- - `after_inference_ts`: Vector of timestamps after inference ends
20
- - `before_iteration_ts`: Vector of vectors of timestamps before each iteration
21
- - `after_iteration_ts`: Vector of vectors of timestamps after each iteration
22
- - `before_autostart_ts`: Vector of timestamps before autostart
23
- - `after_autostart_ts`: Vector of timestamps after autostart
28
+ - `before_model_creation_ts`: CircularBuffer of timestamps before model creation
29
+ - `after_model_creation_ts`: CircularBuffer of timestamps after model creation
30
+ - `before_inference_ts`: CircularBuffer of timestamps before inference starts
31
+ - `after_inference_ts`: CircularBuffer of timestamps after inference ends
32
+ - `before_iteration_ts`: CircularBuffer of vectors of timestamps before each iteration
33
+ - `after_iteration_ts`: CircularBuffer of vectors of timestamps after each iteration
34
+ - `before_autostart_ts`: CircularBuffer of timestamps before autostart
35
+ - `after_autostart_ts`: CircularBuffer of timestamps after autostart
24
36
25
37
# Example
26
38
```julia
@@ -41,19 +53,31 @@ callbacks
41
53
```
42
54
"""
43
55
struct RxInferBenchmarkCallbacks
44
- before_model_creation_ts:: Vector {UInt64}
45
- after_model_creation_ts:: Vector {UInt64}
46
- before_inference_ts:: Vector {UInt64}
47
- after_inference_ts:: Vector {UInt64}
48
- before_iteration_ts:: Vector {Vector{UInt64}}
49
- after_iteration_ts:: Vector {Vector{UInt64}}
50
- before_autostart_ts:: Vector {UInt64}
51
- after_autostart_ts:: Vector {UInt64}
56
+ before_model_creation_ts:: CircularBuffer {UInt64}
57
+ after_model_creation_ts:: CircularBuffer {UInt64}
58
+ before_inference_ts:: CircularBuffer {UInt64}
59
+ after_inference_ts:: CircularBuffer {UInt64}
60
+ before_iteration_ts:: CircularBuffer {Vector{UInt64}}
61
+ after_iteration_ts:: CircularBuffer {Vector{UInt64}}
62
+ before_autostart_ts:: CircularBuffer {UInt64}
63
+ after_autostart_ts:: CircularBuffer {UInt64}
52
64
end
53
65
54
- RxInferBenchmarkCallbacks () = RxInferBenchmarkCallbacks (UInt64[], UInt64[], UInt64[], UInt64[], Vector{UInt64}[], Vector{UInt64}[], UInt64[], UInt64[])
66
# Keyword constructor: allocates eight empty circular buffers, one per timestamp
# field of the struct, all created with the same `capacity`.
#
# Keywords
# - `capacity`: size passed to every `CircularBuffer`; defaults to
#   `DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY`.
#
# Throws `ArgumentError` when `capacity` is not strictly positive, so that a
# misconfigured collector is rejected here with a clear message rather than
# depending on how `CircularBuffer` itself reacts to such a size.
function RxInferBenchmarkCallbacks(; capacity = DEFAULT_BENCHMARK_CALLBACKS_BUFFER_CAPACITY)
    capacity > 0 ||
        throw(ArgumentError("`capacity` must be a positive integer, got $(capacity)"))

    # Scalar fields hold one timestamp per entry; the two iteration fields hold
    # one vector of timestamps per entry.
    timestamps() = CircularBuffer{UInt64}(capacity)
    iteration_timestamps() = CircularBuffer{Vector{UInt64}}(capacity)

    # Argument order follows the field order of the struct.
    return RxInferBenchmarkCallbacks(
        timestamps(),            # before_model_creation_ts
        timestamps(),            # after_model_creation_ts
        timestamps(),            # before_inference_ts
        timestamps(),            # after_inference_ts
        iteration_timestamps(),  # before_iteration_ts
        iteration_timestamps(),  # after_iteration_ts
        timestamps(),            # before_autostart_ts
        timestamps(),            # after_autostart_ts
    )
end
55
78
56
79
# Always returns `nothing` for a `RxInferBenchmarkCallbacks` collector; neither
# `warn` nor the `Val` type parameter is inspected.
# NOTE(review): presumably this exempts the benchmark collector from the
# callback-name validation applied to other callback containers — confirm
# against the generic method.
function check_available_callbacks(
    warn, callbacks::RxInferBenchmarkCallbacks, ::Val{AvailableCallbacks}
) where {AvailableCallbacks}
    return nothing
end
80
# Looking up a callback by `name` always yields `nothing` for the benchmark
# collector; `name` is ignored.
function inference_get_callback(callbacks::RxInferBenchmarkCallbacks, name::Symbol)
    return nothing
end
57
81
58
82
# A collector counts as empty while its `before_model_creation_ts` buffer holds
# no entries; the other buffers are not consulted.
function Base.isempty(callbacks::RxInferBenchmarkCallbacks)
    return isempty(callbacks.before_model_creation_ts)
end
59
83
@@ -79,8 +103,6 @@ function inference_invoke_callback(callbacks::RxInferBenchmarkCallbacks, name::S
79
103
end
80
104
end
81
105
82
- inference_get_callback (callbacks:: RxInferBenchmarkCallbacks , name:: Symbol ) = nothing
83
-
84
106
function prettytime (t:: Union{UInt64, Float64} )
85
107
if t < 1e3
86
108
value, units = t, " ns"
96
118
97
119
# Fallback method: any value not handled by a more specific `prettytime` method
# is returned unchanged. `Base.show` applies `prettytime` to every table cell,
# so this is the method that non-numeric cells such as operation names reach.
function prettytime(s)
    return s
end
98
120
99
- function Base. show (io:: IO , callbacks:: RxInferBenchmarkCallbacks )
100
- if isempty (callbacks)
101
- return nothing
102
- end
103
-
104
- header = ([" Operation" , " Min" , " Max" , " Mean" , " Median" , " Std" ],)
105
-
106
- print (io, " RxInfer inference benchmark statistics: $(length (callbacks. before_model_creation_ts)) evaluations \n " )
107
-
108
- model_creation_time = callbacks. after_model_creation_ts .- callbacks. before_model_creation_ts
121
+ """
122
+ get_benchmark_stats(callbacks::RxInferBenchmarkCallbacks)
123
+
124
+ Returns a matrix containing benchmark statistics for different operations in the inference process.
125
+ The matrix contains the following columns:
126
+ 1. Operation name (String)
127
+ 2. Minimum time (Float64)
128
+ 3. Maximum time (Float64)
129
+ 4. Mean time (Float64)
130
+ 5. Median time (Float64)
131
+ 6. Standard deviation (Float64)
132
+
133
+ Each row represents a different operation (model creation, inference, iteration, autostart).
134
+ Times are in nanoseconds.
135
+ """
136
+ function get_benchmark_stats (callbacks:: RxInferBenchmarkCallbacks )
137
+ model_creation_time = collect (callbacks. after_model_creation_ts) .- collect (callbacks. before_model_creation_ts)
109
138
stats_to_show = [(" Model creation" , model_creation_time)]
110
- inference_time = callbacks. after_inference_ts .- callbacks. before_inference_ts
111
- iteration_time = [callbacks. after_iteration_ts[i] .- callbacks. before_iteration_ts[i] for i in 1 : length (callbacks. before_iteration_ts)]
139
+ inference_time = collect ( callbacks. after_inference_ts) .- collect ( callbacks. before_inference_ts)
140
+ iteration_time = [collect ( callbacks. after_iteration_ts[i]) .- collect ( callbacks. before_iteration_ts[i]) for i in 1 : length (callbacks. before_iteration_ts)]
112
141
if length (inference_time) > 0
113
142
push! (stats_to_show, (" Inference" , inference_time))
114
143
push! (stats_to_show, (" Iteration" , reshape (stack (iteration_time), :)))
115
144
end
116
- autostart_time = callbacks. after_autostart_ts .- callbacks. before_autostart_ts
145
+ autostart_time = collect ( callbacks. after_autostart_ts) .- collect ( callbacks. before_autostart_ts)
117
146
if length (autostart_time) > 0
118
147
push! (stats_to_show, (" Autostart" , autostart_time))
119
148
end
@@ -128,6 +157,19 @@ function Base.show(io::IO, callbacks::RxInferBenchmarkCallbacks)
128
157
data[i, 5 ] = convert (Float64, median (time))
129
158
data[i, 6 ] = convert (Float64, std (time))
130
159
end
160
+ return data
161
+ end
162
+
163
# Writes a one-line summary and a table of timing statistics to `io`.
#
# Nothing is printed (and `nothing` is returned) when the collector is empty.
# Otherwise the rows come from `get_benchmark_stats(callbacks)` and every cell is
# passed through `prettytime` before display.
#
# NOTE(review): the string contents below are reproduced exactly as they appear
# in the reviewed text, including the space right after each opening quote and
# the one after `\n`. They look like extraction artifacts — confirm against the
# repository before merging.
function Base.show(io::IO, callbacks::RxInferBenchmarkCallbacks)
    isempty(callbacks) && return nothing

    evaluations = length(callbacks.before_model_creation_ts)
    print(io, " RxInfer inference benchmark statistics: $(evaluations) evaluations \n ")

    header = ([" Operation", " Min", " Max", " Mean", " Median", " Std"],)
    data = get_benchmark_stats(callbacks)

    # Highlight a maximum (column 3) that exceeds ten times the minimum (column 2).
    hl_v = Highlighter(
        (data, i, j) -> (j == 3) && (data[i, j] > 10 * data[i, j - 1]),
        crayon" red bold",
    )

    return pretty_table(
        io,
        data;
        formatters = (s, i, j) -> prettytime(s),
        header = header,
        header_crayon = crayon" yellow bold",
        tf = tf_unicode_rounded,
        highlighters = hl_v,
    )
end
0 commit comments