Skip to content

Commit 7779be2

Browse files
authored
Merge pull request #388 from astro-informatics/mm/ApproxNormedStochasticBenchmark
Use static variables to avoid recalculating the norm
2 parents 19bf277 + c1a256a commit 7779be2

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

cpp/benchmarks/stochastic_algorithm.cc

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,76 @@ BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackward)(benchmark::State &sta
120120
}
121121
}
122122

123+
BENCHMARK_DEFINE_F(StochasticAlgoFixture, ForwardBackwardApproxNorm)(benchmark::State &state) {
124+
// This functor would be defined in Purify
125+
std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
126+
[this]() {
127+
H5::H5Handler h5file(m_input_data_path, m_world);
128+
utilities::vis_params uv_data = H5::stochread_visibility(h5file, m_N, false);
129+
uv_data.units = utilities::vis_units::radians;
130+
auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
131+
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, m_imsizex,
132+
m_imsizey, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
133+
134+
// declaration of static variables to avoid recalculating the normalisation
135+
static auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
136+
*phi, 1000, 1e-5,
137+
m_world.broadcast(Vector<t_complex>::Ones(m_imsizex * m_imsizey).eval()));
138+
139+
static const t_real op_norm = std::get<0>(power_method_stuff);
140+
141+
// set the normalisation of the new phi
142+
phi->set_norm(op_norm);
143+
144+
return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
145+
};
146+
147+
// wavelets
148+
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
149+
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);
150+
151+
// algorithm
152+
sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
153+
fb.itermax(state.range(2))
154+
.step_size(m_beta * sqrt(2))
155+
.sigma(m_sigma * sqrt(2))
156+
.regulariser_strength(m_gamma)
157+
.relative_variation(1e-3)
158+
.residual_tolerance(0)
159+
.tight_frame(true)
160+
.obj_comm(m_world);
161+
162+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
163+
gp->l1_proximal_tolerance(1e-4)
164+
.l1_proximal_nu(1)
165+
.l1_proximal_itermax(50)
166+
.l1_proximal_positivity_constraint(true)
167+
.l1_proximal_real_constraint(true)
168+
.Psi(*wavelets);
169+
fb.g_function(gp);
170+
171+
PURIFY_INFO("Start iteration loop");
172+
173+
while (state.KeepRunning()) {
174+
auto start = std::chrono::high_resolution_clock::now();
175+
fb();
176+
auto end = std::chrono::high_resolution_clock::now();
177+
state.SetIterationTime(b_utilities::duration(start, end, m_world));
178+
}
179+
}
180+
123181
BENCHMARK_REGISTER_F(StochasticAlgoFixture, ForwardBackward)
124182
->Args({128, 10000, 10})
125183
->UseManualTime()
126184
->MinTime(60.0)
127185
->MinWarmUpTime(5.0)
128186
->Repetitions(3) //->ReportAggregatesOnly(true)
129187
->Unit(benchmark::kMillisecond);
188+
189+
BENCHMARK_REGISTER_F(StochasticAlgoFixture, ForwardBackwardApproxNorm)
190+
->Args({128, 10000, 10})
191+
->UseManualTime()
192+
->MinTime(60.0)
193+
->MinWarmUpTime(5.0)
194+
->Repetitions(3) //->ReportAggregatesOnly(true)
195+
->Unit(benchmark::kMillisecond);

0 commit comments

Comments
 (0)