Skip to content

Add ONNX benchmark #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cmake_files/dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ else()
endif()

find_package(CFitsIO REQUIRED)
find_package(yaml-cpp REQUIRED)

if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.30.0")
cmake_policy(SET CMP0167 NEW)
endif()
find_package(Boost COMPONENTS system filesystem REQUIRED)

find_package(yaml-cpp REQUIRED)

find_package(sopt REQUIRED)
set(PURIFY_ONNXRT FALSE)
if (onnxrt)
Expand Down
77 changes: 65 additions & 12 deletions cpp/benchmarks/algorithms_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,36 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbOnnxDistributeImage)(benchmark::State &stat

m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::mpi_serial, m_measurements_distribute_image, wavelets, m_uv_data,
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
false, 1e-3, 1e-2, 50, tf_model_path, nondiff_func_type::Denoiser);
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true, false,
1e-3, 1e-2, 50, tf_model_path, nondiff_func_type::Denoiser);

// Benchmark the application of the algorithm
while (state.KeepRunning()) {
auto start = std::chrono::high_resolution_clock::now();
auto result = (*m_fb)();
auto end = std::chrono::high_resolution_clock::now();
std::cout << "Converged? " << result.good << " , niters = " << result.niters << std::endl;
state.SetIterationTime(b_utilities::duration(start, end, m_world));
}
}

BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbOnnxDistributeGrid)(benchmark::State &state) {
// Create the algorithm - has to be done there to reset the internal state.
// If done in the fixture repeats would start at the solution and converge immediately.

// TODO: Wavelets are constructed but not used in the factory method
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);

t_real const beta = m_sigma * m_sigma;
t_real const gamma = 0.0001;

std::string tf_model_path = purify::models_directory() + "/snr_15_model_dynamic.onnx";

m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::mpi_serial, m_measurements_distribute_grid, wavelets, m_uv_data,
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true, false,
1e-3, 1e-2, 50, tf_model_path, nondiff_func_type::Denoiser);
Comment on lines +216 to +219
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the measurement operator normalisation calculated anywhere? It doesn't appear to be in AlgoFixtureMPI.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No? I'd assume it just uses the default 1.0


// Benchmark the application of the algorithm
while (state.KeepRunning()) {
Expand All @@ -205,23 +233,42 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbOnnxDistributeImage)
->Args({128, 10000, 4, 10, 1})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 1})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
->Repetitions(3) //->ReportAggregatesOnly(true)
->Unit(benchmark::kMillisecond);

BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbOnnxDistributeGrid)
//->Apply(b_utilities::Arguments)
->Args({128, 10000, 4, 10, 1})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 2})
->UseManualTime()
->MinTime(9.0)
->MinWarmUpTime(1.0)
->Repetitions(3) //->ReportAggregatesOnly(true)
->Unit(benchmark::kMillisecond);

#endif

BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbDistributeImage)
//->Apply(b_utilities::Arguments)
->Args({128, 10000, 4, 10, 1})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 1})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
Expand All @@ -233,8 +280,10 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbDistributeGrid)
->Args({128, 10000, 4, 10, 2})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 2})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
Expand All @@ -246,8 +295,10 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, PadmmDistributeImage)
->Args({128, 10000, 4, 10, 1})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 2})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
Expand All @@ -259,8 +310,10 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, PadmmDistributeGrid)
->Args({128, 10000, 4, 10, 2})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 2})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
Expand Down
6 changes: 3 additions & 3 deletions cpp/benchmarks/utilities.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@ utilities::vis_params random_measurements(t_int size, const t_real max_w, const
std::ifstream vis_file_str(vis_file);

utilities::vis_params uv_data;
if (vis_file_str.good()) {
if (false) {
PURIFY_INFO("Reading random visibilities from file {}", vis_file);
uv_data = utilities::read_visibility(vis_file, true);
uv_data.units = utilities::vis_units::radians;
} else {
PURIFY_INFO("Generating random visibilities and writing to {}", vis_file);
PURIFY_INFO("Generating random visibilities");
t_real const sigma_m = constant::pi / 3;
uv_data = utilities::random_sample_density(size, 0, sigma_m, max_w);
uv_data.units = utilities::vis_units::radians;
utilities::write_visibility(uv_data, vis_file, true);
// utilities::write_visibility(uv_data, vis_file, true);
}
return uv_data;
}
Expand Down