Skip to content

Fix ParILUT threshold select for larger matrices with GPU backends #1877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/cuda_hip/factorization/par_ilut_select_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void threshold_select(std::shared_ptr<const DefaultExecutor> exec,
partial_counts, total_counts);
auto new_bucket = sampleselect_find_bucket(exec, total_counts, rank);
sampleselect_filter(exec, tmp_in, bucket.size, oracles, partial_counts,
bucket.idx, tmp_out);
new_bucket.idx, tmp_out);

rank -= new_bucket.begin;
bucket.size = new_bucket.size;
Expand Down
1 change: 1 addition & 0 deletions contributors.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ license, as specified in the repository's LICENSE file.

Aliaga José I. <aliaga@uji.es> Universitat Jaume I
Anzt Hartwig <hartwig.anzt@kit.edu> Karlsruhe Institute of Technology, The University of Tennessee Knoxville
Beams Natalie <nbeams@icl.utk.edu> University of Tennessee, Knoxville
Boman Erik <egboman@sandia.gov> Sandia National Laboratories
Castelli Fabian <fabian.castelli@kit.edu> Karlsruhe Institute of Technology
Chen Yenchen <yanjen224@gmail.com> National Taiwan University
Expand Down
4 changes: 2 additions & 2 deletions dpcpp/factorization/par_ilut_select_kernel.dp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -122,7 +122,7 @@ void threshold_select(std::shared_ptr<const DefaultExecutor> exec,
partial_counts, total_counts);
auto new_bucket = sampleselect_find_bucket(exec, total_counts, rank);
sampleselect_filter(exec, tmp_in, bucket.size, oracles, partial_counts,
bucket.idx, tmp_out);
new_bucket.idx, tmp_out);

rank -= new_bucket.begin;
bucket.size = new_bucket.size;
Expand Down
10 changes: 8 additions & 2 deletions test/factorization/par_ilut_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class ParIlut : public CommonTestFixture {
mtx_size[0], true,
std::uniform_int_distribution<index_type>(1, mtx_size[0]),
std::normal_distribution<>(0.0, 1.0), rand_engine, ref);
mtx_l3 = gko::test::generate_random_lower_triangular_matrix<Csr>(
800, false, std::uniform_int_distribution<index_type>(256, 256),
std::normal_distribution<>(-0.667, 3.333), rand_engine, ref);
mtx_u = gko::test::generate_random_upper_triangular_matrix<Csr>(
mtx_size[0], false,
std::uniform_int_distribution<index_type>(10, mtx_size[0]),
Expand All @@ -79,6 +82,7 @@ class ParIlut : public CommonTestFixture {
dmtx_ut_ani = Csr::create(exec);
dmtx_l = gko::clone(exec, mtx_l);
dmtx_l2 = gko::clone(exec, mtx_l2);
dmtx_l3 = gko::clone(exec, mtx_l3);
dmtx_u = gko::clone(exec, mtx_u);

std::string file_name(gko::matrices::location_ani4_mtx);
Expand Down Expand Up @@ -219,6 +223,7 @@ class ParIlut : public CommonTestFixture {
std::unique_ptr<Csr> mtx_ut_ani;
std::unique_ptr<Csr> mtx_l;
std::unique_ptr<Csr> mtx_l2;
std::unique_ptr<Csr> mtx_l3;
std::unique_ptr<Csr> mtx_u;

std::unique_ptr<Csr> dmtx1;
Expand All @@ -230,6 +235,7 @@ class ParIlut : public CommonTestFixture {
std::unique_ptr<Csr> dmtx_ut_ani;
std::unique_ptr<Csr> dmtx_l;
std::unique_ptr<Csr> dmtx_l2;
std::unique_ptr<Csr> dmtx_l3;
std::unique_ptr<Csr> dmtx_u;
};

Expand All @@ -246,8 +252,8 @@ TYPED_TEST(ParIlut, KernelThresholdSelectIsEquivalentToRef)
SKIP_IF_BFLOAT16(value_type);
#endif

this->test_select(this->mtx_l, this->dmtx_l,
this->mtx_l->get_num_stored_elements() / 3);
this->test_select(this->mtx_l3, this->dmtx_l3,
this->mtx_l3->get_num_stored_elements() / 3);
}


Expand Down
Loading