Skip to content

Commit 1b9265e

Browse files
committed
Fix file name string bug, use mpi and hdf5 to stochastic read vis file
1 parent b534c22 commit 1b9265e

File tree

1 file changed

+62
-83
lines changed

1 file changed

+62
-83
lines changed

cpp/benchmarks/stochastic_algorithm.cc

Lines changed: 62 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,27 @@
99
#include "purify/measurement_operator_factory.h"
1010
#include "purify/operators.h"
1111
#include "purify/utilities.h"
12+
#include "purify/mpi_utilities.h"
1213
#include "purify/uvw_utilities.h"
1314
#include "purify/wavelet_operator_factory.h"
1415
#include <sopt/imaging_padmm.h>
16+
#include <sopt/power_method.h>
1517
#include <sopt/relative_variation.h>
1618
#include <sopt/utilities.h>
1719
#include <sopt/wavelets.h>
1820
#include <sopt/wavelets/sara.h>
19-
#include <sopt/power_method.h>
20-
21-
#include "purify/test_data.h"
21+
#include <sopt/mpi/communicator.h>
22+
#include <sopt/mpi/session.h>
2223

24+
#ifdef PURIFY_H5
25+
#include "purify/h5reader.h"
26+
#endif
2327

2428
using namespace purify;
2529

2630
class StochasticAlgoFixture : public ::benchmark::Fixture {
2731
public:
28-
2932
void SetUp(const ::benchmark::State &state) {
30-
3133
// m_uv_data = utilities::read_visibility(input_data_path, false);
3234
// m_uv_data.units = utilities::vis_units::radians;
3335

@@ -39,16 +41,20 @@ class StochasticAlgoFixture : public ::benchmark::Fixture {
3941
m_gamma = 0.0001;
4042

4143
m_N = 1000;
42-
44+
45+
// m_input_data_path = data_filename("expected/fb/input_data.vis");
46+
m_input_data_path = data_filename("ska_mid/uvw_ska1mid197_simulation_12h_dt_60.h5");
47+
4348
}
4449

4550
void TearDown(const ::benchmark::State &state) {}
4651

47-
//const std::string &input_data_path = data_filename("ska_mid/uvw_ska1mid197_simulation_12h_dt_60.h5");
48-
const std::string &m_input_data_path = data_filename("expected/fb/input_data.vis");
52+
sopt::mpi::Communicator m_world;
53+
54+
std::string m_input_data_path;
55+
56+
// utilities::vis_params m_uv_data;
4957

50-
//utilities::vis_params m_uv_data;
51-
5258
t_uint m_imsizey;
5359
t_uint m_imsizex;
5460

@@ -57,104 +63,77 @@ class StochasticAlgoFixture : public ::benchmark::Fixture {
5763
t_real m_gamma;
5864

5965
size_t m_N;
60-
66+
6167
std::vector<std::tuple<std::string, t_uint>> const m_sara{
6268
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
6369
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
6470
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
65-
6671
};
6772

6873
BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackward)(benchmark::State &state) {
69-
70-
// This functor would be defined in Purify
71-
std::mt19937 rng(0);
72-
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
73-
[this, &rng]() {
74-
utilities::vis_params uv_data = utilities::read_visibility(m_input_data_path, false);
74+
// This functor would be defined in Purify
75+
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
76+
[this]() {
77+
H5::H5Handler h5file(m_input_data_path, m_world);
78+
utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, true);
7579
uv_data.units = utilities::vis_units::radians;
76-
77-
// Get random subset
78-
std::vector<size_t> indices(uv_data.size());
79-
size_t i = 0;
80-
for (auto &x : indices) {
81-
x = i++;
82-
}
83-
84-
std::shuffle(indices.begin(), indices.end(), rng);
85-
Vector<t_real> u_fragment(m_N);
86-
Vector<t_real> v_fragment(m_N);
87-
Vector<t_real> w_fragment(m_N);
88-
Vector<t_complex> vis_fragment(m_N);
89-
Vector<t_complex> weights_fragment(m_N);
90-
for (i = 0; i < m_N; i++) {
91-
size_t j = indices[i];
92-
u_fragment[i] = uv_data.u[j];
93-
v_fragment[i] = uv_data.v[j];
94-
w_fragment[i] = uv_data.w[j];
95-
vis_fragment[i] = uv_data.vis[j];
96-
weights_fragment[i] = uv_data.weights[j];
97-
}
98-
utilities::vis_params uv_data_fragment(u_fragment, v_fragment, w_fragment, vis_fragment,
99-
weights_fragment, uv_data.units, uv_data.ra,
100-
uv_data.dec, uv_data.average_frequency);
101-
10280
auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
103-
factory::distributed_measurement_operator::serial, uv_data_fragment, m_imsizey, m_imsizex,
104-
1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
81+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128, 128, 1,
82+
1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
10583

106-
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis, phi);
84+
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
10785
};
108-
109-
Vector<t_complex> const init = Vector<t_complex>::Ones(m_imsizex * m_imsizey);
110-
111-
PURIFY_INFO("Call random_updater");
112-
113-
auto IS = random_updater();
114-
auto Phi = IS->Phi();
115-
116-
PURIFY_INFO("Call power method");
117-
118-
auto const power_method_stuff =
119-
sopt::algorithm::power_method<Vector<t_complex>>(Phi, 1000, 1e-5, init);
120-
const t_real op_norm = std::get<0>(power_method_stuff);
121-
122-
PURIFY_INFO("Construct wavelets");
123-
124-
// wavelets
125-
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
86+
87+
Vector<t_complex> const init = Vector<t_complex>::Ones(m_imsizex * m_imsizey);
88+
89+
PURIFY_INFO("Call random_updater");
90+
91+
auto IS = random_updater();
92+
auto Phi = IS->Phi();
93+
94+
PURIFY_INFO("Call power method");
95+
96+
auto const power_method_stuff =
97+
sopt::algorithm::power_method<Vector<t_complex>>(Phi, 1000, 1e-5, m_world.broadcast(init.eval()));
98+
const t_real op_norm = std::get<0>(power_method_stuff);
99+
100+
PURIFY_INFO("Construct wavelets");
101+
102+
// wavelets
103+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
126104
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
127105

128-
PURIFY_INFO("Construct fb algorithm with random updater");
129-
130-
// algorithm
131-
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
132-
fb.itermax(state.range(1))
106+
PURIFY_INFO("Construct fb algorithm with random updater");
107+
108+
// algorithm
109+
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
110+
fb.itermax(state.range(1))
133111
.step_size(m_beta * sqrt(2))
134112
.sigma(m_sigma * sqrt(2))
135113
.regulariser_strength(m_gamma)
136114
.relative_variation(1e-3)
137115
.residual_tolerance(0)
138116
.tight_frame(true)
139-
.sq_op_norm(op_norm * op_norm);
117+
.sq_op_norm(op_norm * op_norm)
118+
.obj_comm(m_world);
140119

141-
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
142-
gp->l1_proximal_tolerance(1e-4)
120+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
121+
gp->l1_proximal_tolerance(1e-4)
143122
.l1_proximal_nu(1)
144123
.l1_proximal_itermax(50)
145124
.l1_proximal_positivity_constraint(true)
146125
.l1_proximal_real_constraint(true)
147126
.Psi(*wavelets);
148-
fb.g_function(gp);
149-
150-
PURIFY_INFO("Start iteration loop");
151-
152-
while (state.KeepRunning()) {
153-
auto start = std::chrono::high_resolution_clock::now();
154-
fb();
155-
auto end = std::chrono::high_resolution_clock::now();
156-
state.SetIterationTime(b_utilities::duration(start, end));
157-
}
127+
fb.g_function(gp);
128+
129+
PURIFY_INFO("Start iteration loop");
130+
131+
while (state.KeepRunning()) {
132+
auto start = std::chrono::high_resolution_clock::now();
133+
fb();
134+
auto end = std::chrono::high_resolution_clock::now();
135+
state.SetIterationTime(b_utilities::duration(start, end, m_world));
136+
}
158137
}
159138

160139
BENCHMARK_REGISTER_F(StochasticAlgoFixture, ForwardBackward)

0 commit comments

Comments
 (0)