Skip to content

WIP: Multivector interface #1889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ target_sources(
matrix/fft.cpp
matrix/hybrid.cpp
matrix/identity.cpp
matrix/multivector.cpp
matrix/permutation.cpp
matrix/row_gatherer.cpp
matrix/scaled_permutation.cpp
Expand Down
313 changes: 313 additions & 0 deletions core/matrix/multivector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,313 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/matrix/multivector.hpp>

namespace gko {
namespace matrix {


MultiVector::MultiVector(std::shared_ptr<const Executor> exec,
const dim<2>& size)
: EnableAbstractPolymorphicObject<MultiVector, LinOp>(std::move(exec), size)
{}


std::unique_ptr<MultiVector> MultiVector::create_with_config_of(
ptr_param<const MultiVector> other)
{
return other->create_generic_with_same_config_impl();
}


std::unique_ptr<MultiVector> MultiVector::create_with_type_of(
ptr_param<const MultiVector> other, std::shared_ptr<const Executor> exec)
{
return other->create_generic_with_type_of_impl(std::move(exec), {}, {}, 0);
}


std::unique_ptr<MultiVector> MultiVector::create_with_type_of(
ptr_param<const MultiVector> other, std::shared_ptr<const Executor> exec,
const dim<2>& global_size, const dim<2>& local_size)
{
GKO_ASSERT_EQUAL_COLS(global_size, local_size);
return other->create_generic_with_type_of_impl(std::move(exec), global_size,
local_size, global_size[1]);
}


std::unique_ptr<MultiVector> MultiVector::create_with_type_of(
ptr_param<const MultiVector> other, std::shared_ptr<const Executor> exec,
const dim<2>& global_size, const dim<2>& local_size, size_type stride)
{
return other->create_generic_with_type_of_impl(std::move(exec), global_size,
local_size, stride);
}


std::unique_ptr<MultiVector> MultiVector::compute_absolute() const
{
return this->compute_absolute_generic_impl();
}


void MultiVector::compute_absolute_inplace()
{
this->compute_absolute_inplace_impl();
}


std::unique_ptr<MultiVector> MultiVector::make_complex() const
{
return this->make_complex_generic_impl();
}


void MultiVector::make_complex(ptr_param<MultiVector> result) const
{
this->make_complex_impl(result.get());
}


std::unique_ptr<MultiVector> MultiVector::get_real() const
{
return this->get_real_generic_impl();
}


void MultiVector::get_real(ptr_param<MultiVector> result) const
{
this->get_real_impl(result.get());
}


std::unique_ptr<MultiVector> MultiVector::get_imag() const
{
return this->get_imag_generic_impl();
}


void MultiVector::get_imag(ptr_param<MultiVector> result) const
{
this->get_imag_impl(result.get());
}


void MultiVector::fill(syn::variant_from_tuple<supported_value_types> value)
{
this->fill_impl(value);
}


void MultiVector::scale(any_const_dense_t alpha) { this->scale_impl(alpha); }


void MultiVector::inv_scale(any_const_dense_t alpha)
{
this->inv_scale_impl(alpha);
}


void MultiVector::add_scaled(any_const_dense_t alpha,
ptr_param<const MultiVector> b)
{
this->add_scaled_impl(alpha, b.get());
}


void MultiVector::sub_scaled(any_const_dense_t alpha,
ptr_param<const MultiVector> b)
{
this->sub_scaled_impl(alpha, b.get());
}


void MultiVector::compute_dot(ptr_param<const MultiVector> b,
ptr_param<MultiVector> result) const
{
this->compute_dot_impl(b.get(), result.get());
}


void MultiVector::compute_dot(ptr_param<const MultiVector> b,
ptr_param<MultiVector> result,
array<char>& tmp) const
{
this->compute_dot_impl(b.get(), result.get(), tmp);
}


void MultiVector::compute_conj_dot(ptr_param<const MultiVector> b,
ptr_param<MultiVector> result) const
{
this->compute_conj_dot_impl(b.get(), result.get());
}


void MultiVector::compute_conj_dot(ptr_param<const MultiVector> b,
ptr_param<MultiVector> result,
array<char>& tmp) const
{
this->compute_conj_dot_impl(b.get(), result.get(), tmp);
}


void MultiVector::compute_norm2(ptr_param<MultiVector> result) const
{
this->compute_norm2_impl(result.get());
}


void MultiVector::compute_norm2(ptr_param<MultiVector> result,
array<char>& tmp) const
{
this->compute_norm2_impl(result.get(), tmp);
}


void MultiVector::compute_squared_norm2(ptr_param<MultiVector> result) const
{
this->compute_squared_norm2_impl(result.get());
}


void MultiVector::compute_squared_norm2(ptr_param<MultiVector> result,
array<char>& tmp) const
{
this->compute_squared_norm2_impl(result.get(), tmp);
}


void MultiVector::compute_norm1(ptr_param<MultiVector> result) const
{
this->compute_norm1_impl(result.get());
}


void MultiVector::compute_norm1(ptr_param<MultiVector> result,
array<char>& tmp) const
{
this->compute_norm1_impl(result.get(), tmp);
}


std::unique_ptr<const MultiVector> MultiVector::create_real_view() const
{
return this->create_real_view_generic_impl();
}


std::unique_ptr<MultiVector> MultiVector::create_real_view()
{
return this->create_real_view_generic_impl();
}


std::unique_ptr<MultiVector> MultiVector::create_subview(local_span rows,
local_span columns)
{
return this->create_subview_generic_impl(rows, columns);
}


std::unique_ptr<const MultiVector> MultiVector::create_subview(
local_span rows, local_span columns) const
{
return this->create_subview_generic_impl(rows, columns);
}


std::unique_ptr<const MultiVector> MultiVector::create_subview(
local_span rows, local_span columns, size_type global_rows,
size_type globals_cols) const
{
return this->create_subview_generic_impl(rows, columns, global_rows,
globals_cols);
}


std::unique_ptr<MultiVector> MultiVector::create_subview(local_span rows,
local_span columns,
size_type global_rows,
size_type globals_cols)
{
return this->create_subview_generic_impl(rows, columns, global_rows,
globals_cols);
}


dim<2> MultiVector::get_size() const noexcept { return LinOp::get_size(); }


size_type MultiVector::get_stride() const noexcept { return get_stride_impl(); }


void MultiVector::set_size(const dim<2>& size) noexcept
{
LinOp::set_size(size);
}


template <typename ValueType>
std::unique_ptr<Dense<ValueType>> MultiVector::create_local_view()
{
using return_type = std::unique_ptr<Dense<ValueType>>;
auto variant = this->create_local_view_impl(ValueType());
if (!std::holds_alternative<return_type>(variant)) {
GKO_INVALID_STATE("Unexpected type of local view");
}
return std::move(std::get<return_type>(variant));
}

#define GKO_DECLARE_MULTIVECTOR_CREATE_LOCAL_VIEW(_type) \
std::unique_ptr<Dense<_type>> MultiVector::create_local_view()
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MULTIVECTOR_CREATE_LOCAL_VIEW);


template <typename ValueType>
std::unique_ptr<const Dense<ValueType>> MultiVector::create_local_view() const
{
using return_type = std::unique_ptr<const Dense<ValueType>>;
auto variant = this->create_local_view_impl(ValueType());
if (!std::holds_alternative<return_type>(variant)) {
GKO_INVALID_STATE("Unexpected type of local view");
}
return std::move(std::get<return_type>(variant));
}

#define GKO_DECLARE_MULTIVECTOR_CREATE_LOCAL_VIEW_CONST(_type) \
std::unique_ptr<const Dense<_type>> MultiVector::create_local_view() const
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_DECLARE_MULTIVECTOR_CREATE_LOCAL_VIEW_CONST);


template <typename ValueType>
auto MultiVector::temporary_precision()
-> std::unique_ptr<MultiVector, std::function<void(MultiVector*)>>
{
return temporary_precision_impl(ValueType());
}

#define GKO_DECLARE_MULTIVECTOR_AS_PRECISION(_type) \
std::unique_ptr<MultiVector, std::function<void(MultiVector*)>> \
MultiVector::temporary_precision<_type>()
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MULTIVECTOR_AS_PRECISION);


template <typename ValueType>
std::unique_ptr<const MultiVector> MultiVector::temporary_precision() const
{
return temporary_precision_impl(ValueType());
}

#define GKO_DECLARE_MULTIVECTOR_AS_PRECISION_CONST(_type) \
std::unique_ptr<const MultiVector> \
MultiVector::temporary_precision<_type>() const
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MULTIVECTOR_AS_PRECISION_CONST);


} // namespace matrix
} // namespace gko
Loading
Loading