Skip to content

Commit 26ed6e1

Browse files
committed
MPI stochastic algorithm benchmark
1 parent 1b9265e commit 26ed6e1

File tree

3 files changed

+29
-46
lines changed

3 files changed

+29
-46
lines changed

cpp/benchmarks/CMakeLists.txt

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ target_link_libraries(fft benchmark libpurify)
1111
add_executable(degridding utilities.cc degridding.cc)
1212
target_link_libraries(degridding benchmark libpurify)
1313

14-
add_executable(stochastic_algorithm utilities.cc stochastic_algorithm.cc)
15-
target_link_libraries(stochastic_algorithm benchmark libpurify)
16-
17-
1814
# Skip ArrayFire benchmarks for now, add back later if needed
1915
# if(doaf)
2016
# add_benchmark(measurement_operator_af utilities.cc LIBRARIES libpurify)
@@ -25,14 +21,16 @@ target_link_libraries(stochastic_algorithm benchmark libpurify)
2521
if(dompi)
2622
add_executable(mpi_benchmark_MO main.cc utilities.cc measurement_operator_mpi.cc)
2723
target_link_libraries(mpi_benchmark_MO ${MPI_LIBRARIES} benchmark libpurify)
28-
#target_include_directories(mpi_benchmark_MO PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
24+
2925
add_executable(mpi_benchmark_MO_wproj main.cc utilities.cc measurement_operator_wproj.cc)
3026
target_link_libraries(mpi_benchmark_MO_wproj ${MPI_LIBRARIES} benchmark libpurify)
31-
#target_include_directories(mpi_benchmark_MO_wproj PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
27+
3228
add_executable(mpi_benchmark_algorithms main.cc utilities.cc algorithms_mpi.cc)
3329
target_link_libraries(mpi_benchmark_algorithms ${MPI_LIBRARIES} benchmark libpurify)
34-
#target_include_directories(mpi_benchmark_PADMM PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
30+
3531
add_executable(mpi_benchmark_WLO main.cc utilities.cc wavelet_operator_mpi.cc)
3632
target_link_libraries(mpi_benchmark_WLO ${MPI_LIBRARIES} benchmark libpurify)
37-
#target_include_directories(mpi_benchmark_WLO PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
33+
34+
add_executable(mpi_benchmark_stochastic_algorithm main.cc utilities.cc stochastic_algorithm.cc)
35+
target_link_libraries(mpi_benchmark_stochastic_algorithm ${MPI_LIBRARIES} benchmark libpurify)
3836
endif()

cpp/benchmarks/algorithms_mpi.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, PadmmDistributeImage)(benchmark::State &state
8989

9090
m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
9191
factory::algo_distribution::mpi_distributed, m_measurements_distribute_image, wavelets,
92-
m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
92+
m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true,
9393
false, 1e-3, 1e-2, 50, 1.0, 1.0);
9494

9595
// Benchmark the application of the algorithm
@@ -110,7 +110,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, PadmmDistributeGrid)(benchmark::State &state)
110110

111111
m_padmm = factory::padmm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
112112
factory::algo_distribution::mpi_distributed, m_measurements_distribute_grid, wavelets,
113-
m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
113+
m_uv_data, m_sigma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true,
114114
false, 1e-3, 1e-2, 50, 1.0, 1.0);
115115

116116
// Benchmark the application of the algorithm
@@ -134,7 +134,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbDistributeImage)(benchmark::State &state) {
134134

135135
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
136136
factory::algo_distribution::mpi_serial, m_measurements_distribute_image, wavelets, m_uv_data,
137-
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
137+
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true,
138138
false, 1e-3, 1e-2, 50, 1.0);
139139

140140
// Benchmark the application of the algorithm
@@ -158,7 +158,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbDistributeGrid)(benchmark::State &state) {
158158

159159
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
160160
factory::algo_distribution::mpi_serial, m_measurements_distribute_grid, wavelets, m_uv_data,
161-
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
161+
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true,
162162
false, 1e-3, 1e-2, 50, 1.0);
163163

164164
// Benchmark the application of the algorithm
@@ -187,7 +187,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbOnnxDistributeImage)(benchmark::State &stat
187187

188188
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
189189
factory::algo_distribution::mpi_serial, m_measurements_distribute_image, wavelets, m_uv_data,
190-
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
190+
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true,
191191
false, 1e-3, 1e-2, 50, 1.0, tf_model_path, nondiff_func_type::Denoiser);
192192

193193
// Benchmark the application of the algorithm
@@ -208,7 +208,7 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbOnnxDistributeImage)
208208
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
209209
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
210210
->UseManualTime()
211-
->MinTime(120.0)
211+
->MinTime(60.0)
212212
->MinWarmUpTime(10.0)
213213
->Repetitions(3) //->ReportAggregatesOnly(true)
214214
->Unit(benchmark::kMillisecond);
@@ -223,7 +223,7 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbDistributeImage)
223223
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
224224
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
225225
->UseManualTime()
226-
->MinTime(120.0)
226+
->MinTime(60.0)
227227
->MinWarmUpTime(10.0)
228228
->Repetitions(3) //->ReportAggregatesOnly(true)
229229
->Unit(benchmark::kMillisecond);
@@ -236,7 +236,7 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbDistributeGrid)
236236
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
237237
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
238238
->UseManualTime()
239-
->MinTime(120.0)
239+
->MinTime(60.0)
240240
->MinWarmUpTime(10.0)
241241
->Repetitions(3) //->ReportAggregatesOnly(true)
242242
->Unit(benchmark::kMillisecond);

cpp/benchmarks/stochastic_algorithm.cc

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@
77
#include "purify/algorithm_factory.h"
88
#include "purify/directories.h"
99
#include "purify/measurement_operator_factory.h"
10+
#include "purify/mpi_utilities.h"
1011
#include "purify/operators.h"
1112
#include "purify/utilities.h"
12-
#include "purify/mpi_utilities.h"
1313
#include "purify/uvw_utilities.h"
1414
#include "purify/wavelet_operator_factory.h"
1515
#include <sopt/imaging_padmm.h>
16+
#include <sopt/mpi/communicator.h>
17+
#include <sopt/mpi/session.h>
1618
#include <sopt/power_method.h>
1719
#include <sopt/relative_variation.h>
1820
#include <sopt/utilities.h>
1921
#include <sopt/wavelets.h>
2022
#include <sopt/wavelets/sara.h>
21-
#include <sopt/mpi/communicator.h>
22-
#include <sopt/mpi/session.h>
2323

2424
#ifdef PURIFY_H5
2525
#include "purify/h5reader.h"
@@ -30,8 +30,6 @@ using namespace purify;
3030
class StochasticAlgoFixture : public ::benchmark::Fixture {
3131
public:
3232
void SetUp(const ::benchmark::State &state) {
33-
// m_uv_data = utilities::read_visibility(input_data_path, false);
34-
// m_uv_data.units = utilities::vis_units::radians;
3533

3634
m_imsizex = state.range(0);
3735
m_imsizey = state.range(0);
@@ -40,11 +38,11 @@ class StochasticAlgoFixture : public ::benchmark::Fixture {
4038
m_beta = m_sigma * m_sigma;
4139
m_gamma = 0.0001;
4240

43-
m_N = 1000;
41+
m_N = state.range(1);
4442

45-
// m_input_data_path = data_filename("expected/fb/input_data.vis");
46-
m_input_data_path = data_filename("ska_mid/uvw_ska1mid197_simulation_12h_dt_60.h5");
43+
m_input_data_path = data_filename("expected/fb/input_data.h5");
4744

45+
m_world = sopt::mpi::Communicator::World();
4846
}
4947

5048
void TearDown(const ::benchmark::State &state) {}
@@ -53,8 +51,6 @@ class StochasticAlgoFixture : public ::benchmark::Fixture {
5351

5452
std::string m_input_data_path;
5553

56-
// utilities::vis_params m_uv_data;
57-
5854
t_uint m_imsizey;
5955
t_uint m_imsizex;
6056

@@ -74,40 +70,31 @@ BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackward)(benchmark::State &sta
7470
// This functor would be defined in Purify
7571
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
7672
[this]() {
77-
H5::H5Handler h5file(m_input_data_path, m_world);
78-
utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, true);
73+
H5::H5Handler h5file(m_input_data_path, m_world);
74+
utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, false);
7975
uv_data.units = utilities::vis_units::radians;
8076
auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
81-
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128, 128, 1,
82-
1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
77+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, m_imsizex,
78+
m_imsizey, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
8379

8480
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
8581
};
8682

87-
Vector<t_complex> const init = Vector<t_complex>::Ones(m_imsizex * m_imsizey);
88-
89-
PURIFY_INFO("Call random_updater");
90-
9183
auto IS = random_updater();
9284
auto Phi = IS->Phi();
9385

94-
PURIFY_INFO("Call power method");
86+
auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
87+
Phi, 1000, 1e-5, m_world.broadcast(Vector<t_complex>::Ones(m_imsizex * m_imsizey).eval()));
9588

96-
auto const power_method_stuff =
97-
sopt::algorithm::power_method<Vector<t_complex>>(Phi, 1000, 1e-5, m_world.broadcast(init.eval()));
9889
const t_real op_norm = std::get<0>(power_method_stuff);
9990

100-
PURIFY_INFO("Construct wavelets");
101-
10291
// wavelets
10392
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
10493
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
10594

106-
PURIFY_INFO("Construct fb algorithm with random updater");
107-
10895
// algorithm
10996
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
110-
fb.itermax(state.range(1))
97+
fb.itermax(state.range(2))
11198
.step_size(m_beta * sqrt(2))
11299
.sigma(m_sigma * sqrt(2))
113100
.regulariser_strength(m_gamma)
@@ -137,11 +124,9 @@ BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackward)(benchmark::State &sta
137124
}
138125

139126
BENCHMARK_REGISTER_F(StochasticAlgoFixture, ForwardBackward)
140-
->Args({128, 10})
127+
->Args({128, 10000, 10})
141128
->UseManualTime()
142-
->MinTime(10.0)
129+
->MinTime(60.0)
143130
->MinWarmUpTime(5.0)
144131
->Repetitions(3) //->ReportAggregatesOnly(true)
145132
->Unit(benchmark::kMillisecond);
146-
147-
BENCHMARK_MAIN();

0 commit comments

Comments
 (0)