Skip to content

Commit 24dafd5

Browse files
authored
Merge pull request #384 from astro-informatics/tk/stochastic_algorithm_benchmark
Add Stochastic Algorithm Benchmark
2 parents a639d10 + 0d5acee commit 24dafd5

File tree

3 files changed

+149
-15
lines changed

3 files changed

+149
-15
lines changed

cpp/benchmarks/CMakeLists.txt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,19 @@ target_link_libraries(degridding benchmark::benchmark libpurify)
2525
if(dompi)
2626
add_executable(mpi_benchmark_MO main.cc utilities.cc measurement_operator_mpi.cc)
2727
target_link_libraries(mpi_benchmark_MO ${MPI_LIBRARIES} benchmark::benchmark libpurify)
28-
#target_include_directories(mpi_benchmark_MO PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
28+
2929
add_executable(mpi_benchmark_MO_wproj main.cc utilities.cc measurement_operator_wproj.cc)
3030
target_link_libraries(mpi_benchmark_MO_wproj ${MPI_LIBRARIES} benchmark::benchmark libpurify)
31-
#target_include_directories(mpi_benchmark_MO_wproj PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
31+
3232
add_executable(mpi_benchmark_algorithms main.cc utilities.cc algorithms_mpi.cc)
3333
target_link_libraries(mpi_benchmark_algorithms ${MPI_LIBRARIES} benchmark::benchmark libpurify)
34-
#target_include_directories(mpi_benchmark_PADMM PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
34+
3535
add_executable(mpi_benchmark_WLO main.cc utilities.cc wavelet_operator_mpi.cc)
3636
target_link_libraries(mpi_benchmark_WLO ${MPI_LIBRARIES} benchmark::benchmark libpurify)
37-
#target_include_directories(mpi_benchmark_WLO PUBLIC "${PROJECT_SOURCE_DIR}/cpp" "${CMAKE_CURRENT_BINARY_DIR}/include")
37+
38+
if(hdf5)
39+
add_executable(mpi_benchmark_stochastic main.cc utilities.cc stochastic_algorithm.cc)
40+
target_link_libraries(mpi_benchmark_stochastic ${MPI_LIBRARIES} benchmark::benchmark libpurify)
41+
endif()
42+
3843
endif()

cpp/benchmarks/algorithms_mpi.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbDistributeImage)(benchmark::State &state) {
134134

135135
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
136136
factory::algo_distribution::mpi_serial, m_measurements_distribute_image, wavelets, m_uv_data,
137-
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
138-
false, 1e-3, 1e-2, 50);
137+
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true, false,
138+
1e-3, 1e-2, 50);
139139

140140
// Benchmark the application of the algorithm
141141
while (state.KeepRunning()) {
@@ -158,8 +158,8 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbDistributeGrid)(benchmark::State &state) {
158158

159159
m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
160160
factory::algo_distribution::mpi_serial, m_measurements_distribute_grid, wavelets, m_uv_data,
161-
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
162-
false, 1e-3, 1e-2, 50);
161+
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true, false,
162+
1e-3, 1e-2, 50);
163163

164164
// Benchmark the application of the algorithm
165165
while (state.KeepRunning()) {
@@ -208,7 +208,7 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbOnnxDistributeImage)
208208
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
209209
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
210210
->UseManualTime()
211-
->MinTime(120.0)
211+
->MinTime(60.0)
212212
->MinWarmUpTime(10.0)
213213
->Repetitions(3) //->ReportAggregatesOnly(true)
214214
->Unit(benchmark::kMillisecond);
@@ -223,7 +223,7 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbDistributeImage)
223223
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
224224
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
225225
->UseManualTime()
226-
->MinTime(120.0)
226+
->MinTime(60.0)
227227
->MinWarmUpTime(10.0)
228228
->Repetitions(3) //->ReportAggregatesOnly(true)
229229
->Unit(benchmark::kMillisecond);
@@ -233,10 +233,10 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbDistributeGrid)
233233
->Args({128, 10000, 4, 10, 2})
234234
->Args({1024, static_cast<t_int>(1e6), 4, 10, 2})
235235
->Args({1024, static_cast<t_int>(1e7), 4, 10, 2})
236-
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
237-
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
236+
->Args({1024, static_cast<t_int>(1e8), 4, 10, 2})
237+
->Args({1024, static_cast<t_int>(1e9), 4, 10, 2})
238238
->UseManualTime()
239-
->MinTime(120.0)
239+
->MinTime(60.0)
240240
->MinWarmUpTime(10.0)
241241
->Repetitions(3) //->ReportAggregatesOnly(true)
242242
->Unit(benchmark::kMillisecond);
@@ -259,8 +259,8 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, PadmmDistributeGrid)
259259
->Args({128, 10000, 4, 10, 2})
260260
->Args({1024, static_cast<t_int>(1e6), 4, 10, 2})
261261
->Args({1024, static_cast<t_int>(1e7), 4, 10, 2})
262-
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
263-
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
262+
->Args({1024, static_cast<t_int>(1e8), 4, 10, 2})
263+
->Args({1024, static_cast<t_int>(1e9), 4, 10, 2})
264264
->UseManualTime()
265265
->MinTime(120.0)
266266
->MinWarmUpTime(10.0)
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#include "purify/config.h"
2+
#include "purify/types.h"
3+
#include <array>
4+
#include <random>
5+
#include <benchmark/benchmark.h>
6+
#include "benchmarks/utilities.h"
7+
#include "purify/algorithm_factory.h"
8+
#include "purify/directories.h"
9+
#include "purify/measurement_operator_factory.h"
10+
#include "purify/mpi_utilities.h"
11+
#include "purify/operators.h"
12+
#include "purify/utilities.h"
13+
#include "purify/uvw_utilities.h"
14+
#include "purify/wavelet_operator_factory.h"
15+
#include <sopt/imaging_padmm.h>
16+
#include <sopt/mpi/communicator.h>
17+
#include <sopt/mpi/session.h>
18+
#include <sopt/power_method.h>
19+
#include <sopt/relative_variation.h>
20+
#include <sopt/utilities.h>
21+
#include <sopt/wavelets.h>
22+
#include <sopt/wavelets/sara.h>
23+
24+
#ifdef PURIFY_H5
25+
#include "purify/h5reader.h"
26+
#endif
27+
28+
using namespace purify;
29+
30+
class StochasticAlgoFixture : public ::benchmark::Fixture {
31+
public:
32+
void SetUp(const ::benchmark::State &state) {
33+
m_imsizex = state.range(0);
34+
m_imsizey = state.range(0);
35+
36+
m_sigma = 0.016820222945913496 * std::sqrt(2);
37+
m_beta = m_sigma * m_sigma;
38+
m_gamma = 0.0001;
39+
40+
m_N = state.range(1);
41+
42+
m_input_data_path = data_filename("expected/fb/input_data.h5");
43+
44+
m_world = sopt::mpi::Communicator::World();
45+
}
46+
47+
void TearDown(const ::benchmark::State &state) {}
48+
49+
sopt::mpi::Communicator m_world;
50+
51+
std::string m_input_data_path;
52+
53+
t_uint m_imsizey;
54+
t_uint m_imsizex;
55+
56+
t_real m_sigma;
57+
t_real m_beta;
58+
t_real m_gamma;
59+
60+
size_t m_N;
61+
62+
std::vector<std::tuple<std::string, t_uint>> const m_sara{
63+
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
64+
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
65+
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
66+
};
67+
68+
BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackward)(benchmark::State &state) {
69+
// This functor would be defined in Purify
70+
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
71+
[this]() {
72+
H5::H5Handler h5file(m_input_data_path, m_world);
73+
utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, false);
74+
uv_data.units = utilities::vis_units::radians;
75+
auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
76+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, m_imsizex,
77+
m_imsizey, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
78+
79+
auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
80+
*phi, 1000, 1e-5,
81+
m_world.broadcast(Vector<t_complex>::Ones(m_imsizex * m_imsizey).eval()));
82+
83+
const t_real op_norm = std::get<0>(power_method_stuff);
84+
phi->set_norm(op_norm);
85+
86+
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
87+
};
88+
89+
// wavelets
90+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
91+
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
92+
93+
// algorithm
94+
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
95+
fb.itermax(state.range(2))
96+
.step_size(m_beta * sqrt(2))
97+
.sigma(m_sigma * sqrt(2))
98+
.regulariser_strength(m_gamma)
99+
.relative_variation(1e-3)
100+
.residual_tolerance(0)
101+
.tight_frame(true)
102+
.obj_comm(m_world);
103+
104+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
105+
gp->l1_proximal_tolerance(1e-4)
106+
.l1_proximal_nu(1)
107+
.l1_proximal_itermax(50)
108+
.l1_proximal_positivity_constraint(true)
109+
.l1_proximal_real_constraint(true)
110+
.Psi(*wavelets);
111+
fb.g_function(gp);
112+
113+
PURIFY_INFO("Start iteration loop");
114+
115+
while (state.KeepRunning()) {
116+
auto start = std::chrono::high_resolution_clock::now();
117+
fb();
118+
auto end = std::chrono::high_resolution_clock::now();
119+
state.SetIterationTime(b_utilities::duration(start, end, m_world));
120+
}
121+
}
122+
123+
BENCHMARK_REGISTER_F(StochasticAlgoFixture, ForwardBackward)
124+
->Args({128, 10000, 10})
125+
->UseManualTime()
126+
->MinTime(60.0)
127+
->MinWarmUpTime(5.0)
128+
->Repetitions(3) //->ReportAggregatesOnly(true)
129+
->Unit(benchmark::kMillisecond);

0 commit comments

Comments
 (0)