Skip to content

Commit 4c86203

Browse files
authored
Update benchmarks (#367)
* Distribute visibilities to get strong scaling benchmarks * Add warmup time to benchmarks * Enable debug logging * Scatter visibilities. Define parameters. * Clean up padmm benchmarks * Use factories in padmm benchmarks * Don't use hard coded paths * Fix variable names * Add some logging to report where vis data is coming from * Update measurements the same way as mpi version * Clarify names, call the right padmm object * Add placeholders to logging command * Reduce sopt verbosity * Fix bugs and lint. Add smaller test for comparison with serial version * Add info to factory functions * Add FB algorithms, rename to be more descriptive * Help purify look up the ONNX runtime * Install data files for TF models * Add FB benchmark that uses onnx rt * Linting * More linting * Linting++ * One more for the linter * Remove obsolete reference to notinstalled namespace * Use correct models directory * remove onnxrt lookup, should be provided by sopt * Remove duplicate * Add bigger problems and more runtime * Linting * refactor + add serial fb algorithms
1 parent f8bfb0b commit 4c86203

File tree

12 files changed

+455
-350
lines changed

12 files changed

+455
-350
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,5 @@ endif()
6363

6464
add_subdirectory(cpp)
6565

66-
6766
# Exports Purify so other packages can access it
6867
include(export_purify)

cmake_files/dependencies.cmake

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ if(tests) # Adds ctest
6565
include(AddCatchTest)
6666
endif()
6767

68-
if(examples)
69-
find_package(TIFF REQUIRED)
70-
endif()
71-
7268
if(tests OR examples)
7369
file(COPY data DESTINATION .)
7470
endif()

cmake_files/export_purify.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ install(FILES
3838
)
3939

4040
install(EXPORT PurifyTargets DESTINATION share/cmake/purify COMPONENT dev)
41+
install(DIRECTORY "${PROJECT_SOURCE_DIR}/data" DESTINATION "${CMAKE_INSTALL_PREFIX}")

cpp/benchmarks/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ if(dompi)
2525
add_executable(mpi_benchmark_MO_wproj main.cc utilities.cc measurement_operator_wproj.cc)
2626
target_link_libraries(mpi_benchmark_MO_wproj ${MPI_LIBRARIES} benchmark libpurify)
2727
#target_include_directories(mpi_benchmark_MO_wproj PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
28-
add_executable(mpi_benchmark_PADMM main.cc utilities.cc padmm_mpi.cc)
29-
target_link_libraries(mpi_benchmark_PADMM ${MPI_LIBRARIES} benchmark libpurify)
28+
add_executable(mpi_benchmark_algorithms main.cc utilities.cc algorithms_mpi.cc)
29+
target_link_libraries(mpi_benchmark_algorithms ${MPI_LIBRARIES} benchmark libpurify)
3030
#target_include_directories(mpi_benchmark_PADMM PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
3131
add_executable(mpi_benchmark_WLO main.cc utilities.cc wavelet_operator_mpi.cc)
3232
target_link_libraries(mpi_benchmark_WLO ${MPI_LIBRARIES} benchmark libpurify)

cpp/benchmarks/algorithms.cc

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#include "purify/config.h"
2+
#include "purify/types.h"
3+
#include <array>
4+
#include <benchmark/benchmark.h>
5+
#include "benchmarks/utilities.h"
6+
#include "purify/algorithm_factory.h"
7+
#include "purify/directories.h"
8+
#include "purify/measurement_operator_factory.h"
9+
#include "purify/operators.h"
10+
#include "purify/utilities.h"
11+
#include "purify/wavelet_operator_factory.h"
12+
#include <sopt/imaging_padmm.h>
13+
#include <sopt/relative_variation.h>
14+
#include <sopt/utilities.h>
15+
#include <sopt/wavelets.h>
16+
#include <sopt/wavelets/sara.h>
17+
18+
using namespace purify;
19+
20+
class AlgoFixture : public ::benchmark::Fixture {
21+
public:
22+
void SetUp(const ::benchmark::State &state) {
23+
// Reading image from file and update related quantities
24+
bool newImage = b_utilities::updateImage(state.range(0), m_image, m_imsizex, m_imsizey);
25+
26+
// Generating random uv(w) coverage
27+
bool newMeasurements =
28+
b_utilities::updateMeasurements(state.range(1), m_uv_data, m_epsilon, newImage, m_image);
29+
30+
bool newKernel = m_kernel != state.range(2);
31+
32+
m_kernel = state.range(2);
33+
// creating the measurement operator
34+
const t_real FoV = 1; // deg
35+
const t_real cellsize = FoV / m_imsizex * 60. * 60.;
36+
const bool w_term = false;
37+
m_measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
38+
factory::distributed_measurement_operator::serial, m_uv_data, m_imsizey, m_imsizex,
39+
cellsize, cellsize, 2, kernels::kernel::kb, m_kernel, m_kernel, w_term);
40+
41+
t_real const m_sigma = 0.016820222945913496 * std::sqrt(2); // see test_parameters file
42+
}
43+
44+
void TearDown(const ::benchmark::State &state) {}
45+
46+
t_real m_epsilon;
47+
t_uint m_counter;
48+
t_real m_sigma;
49+
std::vector<std::tuple<std::string, t_uint>> const m_sara{
50+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
51+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
52+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
53+
54+
Image<t_complex> m_image;
55+
t_uint m_imsizex;
56+
t_uint m_imsizey;
57+
58+
utilities::vis_params m_uv_data;
59+
60+
t_uint m_kernel;
61+
std::shared_ptr<sopt::LinearTransform<Vector<t_complex>> const> m_measurements_transform;
62+
std::shared_ptr<sopt::algorithm::ImagingProximalADMM<t_complex>> m_padmm;
63+
std::shared_ptr<sopt::algorithm::ImagingForwardBackward<t_complex>> m_fb;
64+
};
65+
66+
BENCHMARK_DEFINE_F(AlgoFixture, Padmm)(benchmark::State &state) {
67+
// Benchmark the application of the algorithm
68+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
69+
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
70+
71+
m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
72+
factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma,
73+
m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3, 1e-2, 50,
74+
1.0, 1.0);
75+
76+
while (state.KeepRunning()) {
77+
auto start = std::chrono::high_resolution_clock::now();
78+
(*m_padmm)();
79+
auto end = std::chrono::high_resolution_clock::now();
80+
state.SetIterationTime(b_utilities::duration(start, end));
81+
}
82+
}
83+
84+
BENCHMARK_DEFINE_F(AlgoFixture, ForwardBackward)(benchmark::State &state) {
85+
// Benchmark the application of the algorithm
86+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
87+
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
88+
89+
t_real const beta = m_sigma * m_sigma;
90+
t_real const gamma = 0.0001;
91+
92+
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
93+
factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma,
94+
beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3,
95+
1e-2, 50, 1.0);
96+
97+
while (state.KeepRunning()) {
98+
auto start = std::chrono::high_resolution_clock::now();
99+
(*m_fb)();
100+
auto end = std::chrono::high_resolution_clock::now();
101+
state.SetIterationTime(b_utilities::duration(start, end));
102+
}
103+
}
104+
105+
#ifdef PURIFY_ONNXRT
106+
BENCHMARK_DEFINE_F(AlgoFixture, ForwardBackwardOnnx)(benchmark::State &state) {
107+
// Benchmark the application of the algorithm
108+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
109+
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
110+
111+
t_real const beta = m_sigma * m_sigma;
112+
t_real const gamma = 0.0001;
113+
std::string tf_model_path = purify::models_directory() + "/snr_15_model_dynamic.onnx";
114+
115+
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
116+
factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma,
117+
beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3,
118+
1e-2, 50, 1.0, tf_model_path, factory::g_proximal_type::TFGProximal);
119+
120+
while (state.KeepRunning()) {
121+
auto start = std::chrono::high_resolution_clock::now();
122+
(*m_fb)();
123+
auto end = std::chrono::high_resolution_clock::now();
124+
state.SetIterationTime(b_utilities::duration(start, end));
125+
}
126+
}
127+
128+
BENCHMARK_REGISTER_F(AlgoFixture, ForwardBackwardOnnx)
129+
//->Apply(b_utilities::Arguments)
130+
->Args({128, 10000, 4, 10})
131+
->UseManualTime()
132+
->MinTime(10.0)
133+
->MinWarmUpTime(5.0)
134+
->Repetitions(3) //->ReportAggregatesOnly(true)
135+
->Unit(benchmark::kMillisecond);
136+
#endif
137+
138+
BENCHMARK_REGISTER_F(AlgoFixture, Padmm)
139+
//->Apply(b_utilities::Arguments)
140+
->Args({128, 10000, 4, 10})
141+
->UseManualTime()
142+
->MinTime(10.0)
143+
->MinWarmUpTime(5.0)
144+
->Repetitions(3) //->ReportAggregatesOnly(true)
145+
->Unit(benchmark::kMillisecond);
146+
147+
BENCHMARK_REGISTER_F(AlgoFixture, ForwardBackward)
148+
//->Apply(b_utilities::Arguments)
149+
->Args({128, 10000, 4, 10})
150+
->UseManualTime()
151+
->MinTime(10.0)
152+
->MinWarmUpTime(5.0)
153+
->Repetitions(3) //->ReportAggregatesOnly(true)
154+
->Unit(benchmark::kMillisecond);
155+
156+
BENCHMARK_MAIN();

0 commit comments

Comments
 (0)