-
Notifications
You must be signed in to change notification settings - Fork 98
Add OpenMP segmented prefix sum #1837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,113 @@ | ||||||||||||||||||||||||||||||||
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors | ||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||
// SPDX-License-Identifier: BSD-3-Clause | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#ifndef GKO_OMP_COMPONENTS_PREFIX_SUM_HPP_ | ||||||||||||||||||||||||||||||||
#define GKO_OMP_COMPONENTS_PREFIX_SUM_HPP_ | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#include <algorithm> | ||||||||||||||||||||||||||||||||
#include <iterator> | ||||||||||||||||||||||||||||||||
#include <limits> | ||||||||||||||||||||||||||||||||
#include <string> | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#include <omp.h> | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#include "core/base/allocator.hpp" | ||||||||||||||||||||||||||||||||
#include "core/base/iterator_factory.hpp" | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
namespace gko { | ||||||||||||||||||||||||||||||||
namespace kernels { | ||||||||||||||||||||||||||||||||
namespace omp { | ||||||||||||||||||||||||||||||||
namespace components { | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
/* | ||||||||||||||||||||||||||||||||
* Similar to prefix_sum, only reduces within runs of the same key value (each | ||||||||||||||||||||||||||||||||
* key run must only occur once, otherwise the scan operation is not necessarily | ||||||||||||||||||||||||||||||||
* associaive). It also doesn't ignore the last value! | ||||||||||||||||||||||||||||||||
* Similar to thrust::exclusive_scan_by_key | ||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||
template <typename KeyIterator, typename Iterator, | ||||||||||||||||||||||||||||||||
typename ScanOp = | ||||||||||||||||||||||||||||||||
std::plus<typename std::iterator_traits<Iterator>::value_type>> | ||||||||||||||||||||||||||||||||
void segmented_prefix_sum( | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. segmented_scan_by_key? |
||||||||||||||||||||||||||||||||
std::shared_ptr<const OmpExecutor> exec, KeyIterator key, Iterator it, | ||||||||||||||||||||||||||||||||
const size_type num_entries, | ||||||||||||||||||||||||||||||||
typename std::iterator_traits<KeyIterator>::value_type key_init = {}, | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it for the situation when KeyIterator is float or double? |
||||||||||||||||||||||||||||||||
typename std::iterator_traits<Iterator>::value_type init = {}, | ||||||||||||||||||||||||||||||||
ScanOp op = {}) | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As it call prefix_sum, we do not need the |
||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||
using key_type = typename std::iterator_traits<KeyIterator>::value_type; | ||||||||||||||||||||||||||||||||
using value_type = typename std::iterator_traits<Iterator>::value_type; | ||||||||||||||||||||||||||||||||
// the operation only makes sense for arrays of size at least 2 | ||||||||||||||||||||||||||||||||
if (num_entries < 2) { | ||||||||||||||||||||||||||||||||
if (num_entries == 0) { | ||||||||||||||||||||||||||||||||
return; | ||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||
*it = init; | ||||||||||||||||||||||||||||||||
return; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
Comment on lines
+44
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
const int nthreads = omp_get_max_threads(); | ||||||||||||||||||||||||||||||||
vector<value_type> proc_sums(nthreads, init, {exec}); | ||||||||||||||||||||||||||||||||
vector<key_type> proc_first_key(nthreads, key_init, {exec}); | ||||||||||||||||||||||||||||||||
vector<key_type> proc_last_key(nthreads, key_init, {exec}); | ||||||||||||||||||||||||||||||||
const size_type def_num_witems = (num_entries - 1) / nthreads + 1; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#pragma omp parallel | ||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||
const int thread_id = omp_get_thread_num(); | ||||||||||||||||||||||||||||||||
const size_type startidx = thread_id * def_num_witems; | ||||||||||||||||||||||||||||||||
const size_type endidx = | ||||||||||||||||||||||||||||||||
std::min(num_entries, (thread_id + 1) * def_num_witems); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
auto partial_sum = init; | ||||||||||||||||||||||||||||||||
auto cur_key = startidx < num_entries ? key[startidx] : key_init; | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think key_init can be the last key. |
||||||||||||||||||||||||||||||||
proc_first_key[thread_id] = cur_key; | ||||||||||||||||||||||||||||||||
for (size_type i = startidx; i < endidx; ++i) { | ||||||||||||||||||||||||||||||||
auto value = it[i]; | ||||||||||||||||||||||||||||||||
auto new_key = key[i]; | ||||||||||||||||||||||||||||||||
if (cur_key != new_key) { | ||||||||||||||||||||||||||||||||
partial_sum = init; | ||||||||||||||||||||||||||||||||
cur_key = new_key; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
it[i] = partial_sum; | ||||||||||||||||||||||||||||||||
partial_sum = op(partial_sum, value); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
proc_sums[thread_id] = partial_sum; | ||||||||||||||||||||||||||||||||
proc_last_key[thread_id] = cur_key; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#pragma omp barrier | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#pragma omp single | ||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||
for (int i = 0; i < nthreads - 1; i++) { | ||||||||||||||||||||||||||||||||
// the next block carries over the previous partial sum | ||||||||||||||||||||||||||||||||
// if it starts and ends with the same key as the next one | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
if (proc_last_key[i] == proc_first_key[i + 1] && | ||||||||||||||||||||||||||||||||
proc_first_key[i + 1] == proc_last_key[i + 1]) { | ||||||||||||||||||||||||||||||||
proc_sums[i + 1] = op(proc_sums[i], proc_sums[i + 1]); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if (thread_id > 0) { | ||||||||||||||||||||||||||||||||
for (size_type i = startidx; i < endidx; i++) { | ||||||||||||||||||||||||||||||||
if (key[i] == proc_last_key[thread_id - 1]) { | ||||||||||||||||||||||||||||||||
it[i] = op(it[i], proc_sums[thread_id - 1]); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
} // namespace components | ||||||||||||||||||||||||||||||||
} // namespace omp | ||||||||||||||||||||||||||||||||
} // namespace kernels | ||||||||||||||||||||||||||||||||
} // namespace gko | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
#endif // GKO_OMP_COMPONENTS_PREFIX_SUM_HPP_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
ginkgo_create_omp_test(prefix_sum) |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,74 @@ | ||||||||||
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors | ||||||||||
// | ||||||||||
// SPDX-License-Identifier: BSD-3-Clause | ||||||||||
|
||||||||||
#include "omp/components/prefix_sum.hpp" | ||||||||||
|
||||||||||
#include <algorithm> | ||||||||||
#include <iterator> | ||||||||||
#include <limits> | ||||||||||
#include <memory> | ||||||||||
#include <random> | ||||||||||
#include <type_traits> | ||||||||||
#include <vector> | ||||||||||
|
||||||||||
#include <gtest/gtest.h> | ||||||||||
|
||||||||||
#include <ginkgo/core/base/executor.hpp> | ||||||||||
|
||||||||||
#include "core/base/index_range.hpp" | ||||||||||
#include "core/test/utils.hpp" | ||||||||||
|
||||||||||
|
||||||||||
template <typename T> | ||||||||||
class PrefixSum : public ::testing::Test { | ||||||||||
protected: | ||||||||||
using index_type = T; | ||||||||||
|
||||||||||
PrefixSum() : exec{gko::OmpExecutor::create()}, rand(293) {} | ||||||||||
|
||||||||||
std::shared_ptr<const gko::OmpExecutor> exec; | ||||||||||
std::default_random_engine rand; | ||||||||||
gko::size_type total_size; | ||||||||||
}; | ||||||||||
|
||||||||||
TYPED_TEST_SUITE(PrefixSum, gko::test::IndexTypes, TypenameNameGenerator); | ||||||||||
|
||||||||||
|
||||||||||
TYPED_TEST(PrefixSum, SegmentedPrefixSumWorks) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. another test for checking the init is not 0 |
||||||||||
{ | ||||||||||
using index_type = typename TestFixture::index_type; | ||||||||||
const auto max_threads = omp_get_max_threads(); | ||||||||||
for (int num_threads = 1; num_threads <= max_threads; num_threads++) { | ||||||||||
SCOPED_TRACE(num_threads); | ||||||||||
omp_set_num_threads(num_threads); | ||||||||||
for (int num_ranges : {10, 100, 1000}) { | ||||||||||
SCOPED_TRACE(num_ranges); | ||||||||||
// repeate multiple times for different random seeds | ||||||||||
for (int repetition : gko::irange{10}) { | ||||||||||
std::uniform_int_distribution<int> count_dist{0, 100}; | ||||||||||
std::uniform_int_distribution<index_type> value_dist{-200, 200}; | ||||||||||
std::vector<index_type> ref_result; | ||||||||||
std::vector<int> keys; | ||||||||||
std::vector<index_type> input; | ||||||||||
for (int i = 0; i < num_ranges; i++) { | ||||||||||
const auto start = keys.size(); | ||||||||||
const auto new_count = count_dist(this->rand); | ||||||||||
keys.insert(keys.end(), new_count, i); | ||||||||||
std::generate_n(std::back_inserter(input), new_count, | ||||||||||
[&] { return value_dist(this->rand); }); | ||||||||||
std::copy(input.begin() + start, input.end(), | ||||||||||
std::back_inserter(ref_result)); | ||||||||||
std::exclusive_scan( | ||||||||||
ref_result.begin() + start, ref_result.end(), | ||||||||||
ref_result.begin() + start, index_type{}); | ||||||||||
} | ||||||||||
|
||||||||||
gko::kernels::omp::components::segmented_prefix_sum( | ||||||||||
this->exec, keys.cbegin(), input.begin(), keys.size()); | ||||||||||
Comment on lines
+67
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
ASSERT_EQ(input, ref_result); | ||||||||||
} | ||||||||||
} | ||||||||||
} | ||||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not get it. if it does not ignore the last value, should it be the inclusive one?