9
9
#include " purify/measurement_operator_factory.h"
10
10
#include " purify/operators.h"
11
11
#include " purify/utilities.h"
12
+ #include " purify/mpi_utilities.h"
12
13
#include " purify/uvw_utilities.h"
13
14
#include " purify/wavelet_operator_factory.h"
14
15
#include < sopt/imaging_padmm.h>
16
+ #include < sopt/power_method.h>
15
17
#include < sopt/relative_variation.h>
16
18
#include < sopt/utilities.h>
17
19
#include < sopt/wavelets.h>
18
20
#include < sopt/wavelets/sara.h>
19
- #include < sopt/power_method.h>
20
-
21
- #include " purify/test_data.h"
21
+ #include < sopt/mpi/communicator.h>
22
+ #include < sopt/mpi/session.h>
22
23
24
+ #ifdef PURIFY_H5
25
+ #include " purify/h5reader.h"
26
+ #endif
23
27
24
28
using namespace purify ;
25
29
26
30
class StochasticAlgoFixture : public ::benchmark::Fixture {
27
31
public:
28
-
29
32
void SetUp (const ::benchmark::State &state) {
30
-
31
33
// m_uv_data = utilities::read_visibility(input_data_path, false);
32
34
// m_uv_data.units = utilities::vis_units::radians;
33
35
@@ -39,16 +41,20 @@ class StochasticAlgoFixture : public ::benchmark::Fixture {
39
41
m_gamma = 0.0001 ;
40
42
41
43
m_N = 1000 ;
42
-
44
+
45
+ // m_input_data_path = data_filename("expected/fb/input_data.vis");
46
+ m_input_data_path = data_filename (" ska_mid/uvw_ska1mid197_simulation_12h_dt_60.h5" );
47
+
43
48
}
44
49
45
50
void TearDown (const ::benchmark::State &state) {}
46
51
47
- // const std::string &input_data_path = data_filename("ska_mid/uvw_ska1mid197_simulation_12h_dt_60.h5");
48
- const std::string &m_input_data_path = data_filename(" expected/fb/input_data.vis" );
52
+ sopt::mpi::Communicator m_world;
53
+
54
+ std::string m_input_data_path;
55
+
56
+ // utilities::vis_params m_uv_data;
49
57
50
- // utilities::vis_params m_uv_data;
51
-
52
58
t_uint m_imsizey;
53
59
t_uint m_imsizex;
54
60
@@ -57,104 +63,77 @@ class StochasticAlgoFixture : public ::benchmark::Fixture {
57
63
t_real m_gamma;
58
64
59
65
size_t m_N;
60
-
66
+
61
67
std::vector<std::tuple<std::string, t_uint>> const m_sara{
62
68
std::make_tuple (" Dirac" , 3u ), std::make_tuple (" DB1" , 3u ), std::make_tuple (" DB2" , 3u ),
63
69
std::make_tuple (" DB3" , 3u ), std::make_tuple (" DB4" , 3u ), std::make_tuple (" DB5" , 3u ),
64
70
std::make_tuple (" DB6" , 3u ), std::make_tuple (" DB7" , 3u ), std::make_tuple (" DB8" , 3u )};
65
-
66
71
};
67
72
68
73
BENCHMARK_DEFINE_F (StochasticAlgoFixture, ForwardBackward)(benchmark::State &state) {
69
-
70
- // This functor would be defined in Purify
71
- std::mt19937 rng (0 );
72
- std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
73
- [this , &rng]() {
74
- utilities::vis_params uv_data = utilities::read_visibility (m_input_data_path, false );
74
+ // This functor would be defined in Purify
75
+ std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
76
+ [this ]() {
77
+ H5::H5Handler h5file (m_input_data_path, m_world);
78
+ utilities::vis_params uv_data = H5::stochread_visibility (h5file, m_N, true );
75
79
uv_data.units = utilities::vis_units::radians;
76
-
77
- // Get random subset
78
- std::vector<size_t > indices (uv_data.size ());
79
- size_t i = 0 ;
80
- for (auto &x : indices) {
81
- x = i++;
82
- }
83
-
84
- std::shuffle (indices.begin (), indices.end (), rng);
85
- Vector<t_real> u_fragment (m_N);
86
- Vector<t_real> v_fragment (m_N);
87
- Vector<t_real> w_fragment (m_N);
88
- Vector<t_complex> vis_fragment (m_N);
89
- Vector<t_complex> weights_fragment (m_N);
90
- for (i = 0 ; i < m_N; i++) {
91
- size_t j = indices[i];
92
- u_fragment[i] = uv_data.u [j];
93
- v_fragment[i] = uv_data.v [j];
94
- w_fragment[i] = uv_data.w [j];
95
- vis_fragment[i] = uv_data.vis [j];
96
- weights_fragment[i] = uv_data.weights [j];
97
- }
98
- utilities::vis_params uv_data_fragment (u_fragment, v_fragment, w_fragment, vis_fragment,
99
- weights_fragment, uv_data.units , uv_data.ra ,
100
- uv_data.dec , uv_data.average_frequency );
101
-
102
80
auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
103
- factory::distributed_measurement_operator::serial, uv_data_fragment, m_imsizey, m_imsizex ,
104
- 1 , 1 , 2 , kernels::kernel_from_string.at (" kb" ), 4 , 4 );
81
+ factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128 , 128 , 1 ,
82
+ 1 , 2 , kernels::kernel_from_string.at (" kb" ), 4 , 4 );
105
83
106
- return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment .vis , phi);
84
+ return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data .vis , phi);
107
85
};
108
-
109
- Vector<t_complex> const init = Vector<t_complex>::Ones (m_imsizex * m_imsizey);
110
-
111
- PURIFY_INFO (" Call random_updater" );
112
-
113
- auto IS = random_updater ();
114
- auto Phi = IS->Phi ();
115
-
116
- PURIFY_INFO (" Call power method" );
117
-
118
- auto const power_method_stuff =
119
- sopt::algorithm::power_method<Vector<t_complex>>(Phi, 1000 , 1e-5 , init);
120
- const t_real op_norm = std::get<0 >(power_method_stuff);
121
-
122
- PURIFY_INFO (" Construct wavelets" );
123
-
124
- // wavelets
125
- auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
86
+
87
+ Vector<t_complex> const init = Vector<t_complex>::Ones (m_imsizex * m_imsizey);
88
+
89
+ PURIFY_INFO (" Call random_updater" );
90
+
91
+ auto IS = random_updater ();
92
+ auto Phi = IS->Phi ();
93
+
94
+ PURIFY_INFO (" Call power method" );
95
+
96
+ auto const power_method_stuff =
97
+ sopt::algorithm::power_method<Vector<t_complex>>(Phi, 1000 , 1e-5 , m_world. broadcast ( init. eval ()) );
98
+ const t_real op_norm = std::get<0 >(power_method_stuff);
99
+
100
+ PURIFY_INFO (" Construct wavelets" );
101
+
102
+ // wavelets
103
+ auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
126
104
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
127
105
128
- PURIFY_INFO (" Construct fb algorithm with random updater" );
129
-
130
- // algorithm
131
- sopt::algorithm::ImagingForwardBackward<t_complex> fb (random_updater);
132
- fb.itermax (state.range (1 ))
106
+ PURIFY_INFO (" Construct fb algorithm with random updater" );
107
+
108
+ // algorithm
109
+ sopt::algorithm::ImagingForwardBackward<t_complex> fb (random_updater);
110
+ fb.itermax (state.range (1 ))
133
111
.step_size (m_beta * sqrt (2 ))
134
112
.sigma (m_sigma * sqrt (2 ))
135
113
.regulariser_strength (m_gamma)
136
114
.relative_variation (1e-3 )
137
115
.residual_tolerance (0 )
138
116
.tight_frame (true )
139
- .sq_op_norm (op_norm * op_norm);
117
+ .sq_op_norm (op_norm * op_norm)
118
+ .obj_comm (m_world);
140
119
141
- auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false );
142
- gp->l1_proximal_tolerance (1e-4 )
120
+ auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false );
121
+ gp->l1_proximal_tolerance (1e-4 )
143
122
.l1_proximal_nu (1 )
144
123
.l1_proximal_itermax (50 )
145
124
.l1_proximal_positivity_constraint (true )
146
125
.l1_proximal_real_constraint (true )
147
126
.Psi (*wavelets);
148
- fb.g_function (gp);
149
-
150
- PURIFY_INFO (" Start iteration loop" );
151
-
152
- while (state.KeepRunning ()) {
153
- auto start = std::chrono::high_resolution_clock::now ();
154
- fb ();
155
- auto end = std::chrono::high_resolution_clock::now ();
156
- state.SetIterationTime (b_utilities::duration (start, end));
157
- }
127
+ fb.g_function (gp);
128
+
129
+ PURIFY_INFO (" Start iteration loop" );
130
+
131
+ while (state.KeepRunning ()) {
132
+ auto start = std::chrono::high_resolution_clock::now ();
133
+ fb ();
134
+ auto end = std::chrono::high_resolution_clock::now ();
135
+ state.SetIterationTime (b_utilities::duration (start, end, m_world ));
136
+ }
158
137
}
159
138
160
139
BENCHMARK_REGISTER_F (StochasticAlgoFixture, ForwardBackward)
0 commit comments