Skip to content

Add ONNX benchmark #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cmake_files/dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ else()
endif()

find_package(CFitsIO REQUIRED)
find_package(yaml-cpp REQUIRED)

if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.30.0")
cmake_policy(SET CMP0167 NEW)
endif()
find_package(Boost COMPONENTS system filesystem REQUIRED)

find_package(yaml-cpp REQUIRED)

find_package(sopt REQUIRED)
set(PURIFY_ONNXRT FALSE)
if (onnxrt)
Expand Down
77 changes: 65 additions & 12 deletions cpp/benchmarks/algorithms_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,36 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbOnnxDistributeImage)(benchmark::State &stat

m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::mpi_serial, m_measurements_distribute_image, wavelets, m_uv_data,
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true,
false, 1e-3, 1e-2, 50, tf_model_path, nondiff_func_type::Denoiser);
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true, false,
1e-3, 1e-2, 50, tf_model_path, nondiff_func_type::Denoiser);

// Benchmark the application of the algorithm
while (state.KeepRunning()) {
auto start = std::chrono::high_resolution_clock::now();
auto result = (*m_fb)();
auto end = std::chrono::high_resolution_clock::now();
std::cout << "Converged? " << result.good << " , niters = " << result.niters << std::endl;
state.SetIterationTime(b_utilities::duration(start, end, m_world));
}
}

BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbOnnxDistributeGrid)(benchmark::State &state) {
// Create the algorithm - has to be done there to reset the internal state.
// If done in the fixture repeats would start at the solution and converge immediately.

// TODO: Wavelets are constructed but not used in the factory method
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
factory::distributed_wavelet_operator::serial, m_sara, m_imsizey, m_imsizex);

t_real const beta = m_sigma * m_sigma;
t_real const gamma = 0.0001;

std::string tf_model_path = purify::models_directory() + "/snr_15_model_dynamic.onnx";

m_fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::mpi_serial, m_measurements_distribute_grid, wavelets, m_uv_data,
m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3), true, true, false,
1e-3, 1e-2, 50, tf_model_path, nondiff_func_type::Denoiser);
Comment on lines +216 to +219
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the measurement operator normalisation calculated anywhere? It doesn't appear to be in AlgoFixtureMPI.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No? I'd assume it just uses the default 1.0


// Benchmark the application of the algorithm
while (state.KeepRunning()) {
Expand All @@ -205,23 +233,42 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbOnnxDistributeImage)
->Args({128, 10000, 4, 10, 1})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 1})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
->Repetitions(3) //->ReportAggregatesOnly(true)
->Unit(benchmark::kMillisecond);

BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbOnnxDistributeGrid)
//->Apply(b_utilities::Arguments)
->Args({128, 10000, 4, 10, 1})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 2})
->UseManualTime()
->MinTime(9.0)
->MinWarmUpTime(1.0)
->Repetitions(3) //->ReportAggregatesOnly(true)
->Unit(benchmark::kMillisecond);

#endif

BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbDistributeImage)
//->Apply(b_utilities::Arguments)
->Args({128, 10000, 4, 10, 1})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 1})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
Expand All @@ -233,8 +280,10 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, FbDistributeGrid)
->Args({128, 10000, 4, 10, 2})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 2})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
Expand All @@ -246,8 +295,10 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, PadmmDistributeImage)
->Args({128, 10000, 4, 10, 1})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 1})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 1})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
Expand All @@ -259,8 +310,10 @@ BENCHMARK_REGISTER_F(AlgoFixtureMPI, PadmmDistributeGrid)
->Args({128, 10000, 4, 10, 2})
->Args({1024, static_cast<t_int>(1e6), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e7), 4, 10, 2})
->Args({1024, static_cast<t_int>(1e8), 4, 10, 1})
->Args({1024, static_cast<t_int>(1e9), 4, 10, 1})
->Args({2048, static_cast<t_int>(1e6), 4, 10, 2})
->Args({2048, static_cast<t_int>(1e7), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e6), 4, 10, 2})
->Args({4096, static_cast<t_int>(1e7), 4, 10, 2})
->UseManualTime()
->MinTime(120.0)
->MinWarmUpTime(10.0)
Expand Down
14 changes: 9 additions & 5 deletions cpp/benchmarks/utilities.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,28 @@ std::tuple<utilities::vis_params, t_real> dirty_measurements(
return std::make_tuple(uv_data, sigma);
}

utilities::vis_params random_measurements(t_int size, const t_real max_w, const t_int id) {
utilities::vis_params random_measurements(t_int size, const t_real max_w, const t_int id,
const bool cache_visibilities) {
utilities::vis_params uv_data;

std::stringstream filename;
filename << "random_" << size << "_";
filename << std::to_string(id) << ".vis";
std::string const vis_file = visibility_filename(filename.str());
std::ifstream vis_file_str(vis_file);

utilities::vis_params uv_data;
if (vis_file_str.good()) {
if (cache_visibilities and vis_file_str.good()) {
PURIFY_INFO("Reading random visibilities from file {}", vis_file);
uv_data = utilities::read_visibility(vis_file, true);
uv_data.units = utilities::vis_units::radians;
} else {
PURIFY_INFO("Generating random visibilities and writing to {}", vis_file);
PURIFY_INFO("Generating random visibilities");
t_real const sigma_m = constant::pi / 3;
uv_data = utilities::random_sample_density(size, 0, sigma_m, max_w);
uv_data.units = utilities::vis_units::radians;
utilities::write_visibility(uv_data, vis_file, true);
if (cache_visibilities) {
utilities::write_visibility(uv_data, vis_file, true);
}
}
return uv_data;
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/benchmarks/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ std::tuple<utilities::vis_params, t_real> dirty_measurements(
Image<t_complex> const& ground_truth_image, t_uint number_of_vis, t_real snr,
const t_real& cellsize);

utilities::vis_params random_measurements(t_int size, const t_real max_w = 100, const t_int id = 0);
utilities::vis_params random_measurements(t_int size, const t_real max_w = 100, const t_int id = 0,
const bool cache_visibilities = false);
#ifdef PURIFY_MPI
double duration(std::chrono::high_resolution_clock::time_point start,
std::chrono::high_resolution_clock::time_point end,
Expand Down