!43129 [assistant][ops] add new gpu operator Ormqr

Merge pull request !43129 from GP/Ormqr
i-robot 2022-11-18 02:15:27 +00:00 committed by Gitee
commit a84f309598
7 changed files with 833 additions and 0 deletions

mindspore/ccsrc/plugin/device/gpu/kernel/math/ormqr_gpu_kernel.cc Normal file

@@ -0,0 +1,364 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/ormqr_gpu_kernel.h"
#include <complex>
#include <vector>
#include <map>
#include <utility>
#include "abstract/utils.h"
#include "kernel/common_utils.h"
#include "include/common/utils/convert_utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cuh"
namespace mindspore {
namespace kernel {
bool OrmqrGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::Ormqr>(base_operator);
MS_EXCEPTION_IF_NULL(kernel_ptr);
kernel_name_ = kernel_ptr->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [ "
<< "float32, float64, complex64, complex128], but got: " << kernel_attr << ".";
return false;
}
launch_kernel_func_ = func_list_[index].second.first;
init_lists_func_ = func_list_[index].second.second;
left_ = kernel_ptr->get_left();
transpose_ = kernel_ptr->get_transpose();
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
return true;
}
int OrmqrGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
ResetResource();
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
input_x_shape_ = std::vector<size_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
input_tau_shape_ = std::vector<size_t>(inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
input_other_shape_ = std::vector<size_t>(inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex2)->GetDeviceShapeAdaptively().end());
input_x_dims_ = input_x_shape_.size();
input_tau_dims_ = input_tau_shape_.size();
input_other_dims_ = input_other_shape_.size();
is_null_input_ = (input_x_dims_ == 0 || input_tau_dims_ == 0 || input_other_dims_ == 0);
if (is_null_input_) {
init_lists_func_(this);
return 0;
}
batch_size_ = 1;
for (size_t i = 0; i < input_x_dims_ - kDim2; i++) {
if (input_x_shape_[i] != input_tau_shape_[i]) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', x and tau should share the same batch size, but x.shape[" << i
<< "] is " << input_x_shape_[i] << ", and tau.shape[" << i << "] is " << input_tau_shape_[i];
return KRET_RESIZE_FAILED;
}
batch_size_ = batch_size_ * input_x_shape_[i];
}
for (size_t i = 0; i < input_x_dims_ - kDim2; i++) {
if (input_x_shape_[i] != input_other_shape_[i]) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', x and other should share the same batch size, but x.shape[" << i
<< "] is " << input_x_shape_[i] << ", and other.shape[" << i << "] is " << input_other_shape_[i];
return KRET_RESIZE_FAILED;
}
}
side_ = left_ ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
trans_ = transpose_ ? CUBLAS_OP_T : CUBLAS_OP_N;
m_ = input_other_shape_[input_other_dims_ - kDim2];
n_ = input_other_shape_[input_other_dims_ - kDim1];
x_m_ = input_x_shape_[input_x_dims_ - kDim2];
x_n_ = input_x_shape_[input_x_dims_ - kDim1];
tau_n_ = input_tau_shape_[input_tau_dims_ - kDim1];
bool check_inputs = CheckInputs();
if (!check_inputs) {
return KRET_RESIZE_FAILED;
}
for (size_t i = 0; i < input_x_dims_; ++i) {
transpose_input_x_shape_[i] = input_x_shape_[i];
transpose_input_other_shape_[i] = input_other_shape_[i];
if (i == input_x_dims_ - kDim2) {
transpose_input_x_axis_[i] = input_x_dims_ - kDim1;
transpose_output_y_shape_[i] = input_other_shape_[input_other_dims_ - kDim1];
} else if (i == input_x_dims_ - kDim1) {
transpose_input_x_axis_[i] = input_x_dims_ - kDim2;
transpose_output_y_shape_[i] = input_other_shape_[input_other_dims_ - kDim2];
} else {
transpose_input_x_axis_[i] = i;
transpose_output_y_shape_[i] = input_other_shape_[i];
}
}
init_lists_func_(this);
return 0;
}
bool OrmqrGpuKernelMod::CheckInputs() {
if (input_x_dims_ < kDim2) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', dimensions of x must be greater than or equal to 2"
<< ", but got [" << input_x_dims_ << "].";
return false;
}
if (input_other_dims_ < kDim2) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', dimensions of other must be greater than or equal to 2"
<< ", but got [" << input_other_dims_ << "].";
return false;
}
if (input_x_dims_ - kDim1 != input_tau_dims_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', tau should have one dimension less than x"
<< ", but rank of x is " << input_x_dims_ << " and rank of tau is " << input_tau_dims_ << ".";
return false;
}
if (input_x_dims_ != input_other_dims_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', other should have the same rank as x"
<< ", but rank of x is " << input_x_dims_ << " and rank of other is " << input_other_dims_ << ".";
return false;
}
if (left_) {
if (m_ != x_m_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', other.shape[-2] must be equal to x.shape[-2]"
<< ", but x.shape[-2] is " << x_m_ << " and other.shape[-2] is " << m_ << ".";
return false;
}
if (m_ < tau_n_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', other.shape[-2] must be greater than or equal to "
<< "tau.shape[-1], but other.shape[-2] is " << m_ << " and tau.shape[-1] is " << tau_n_;
return false;
}
} else {
if (n_ != x_m_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', other.shape[-1] must be equal to x.shape[-2]"
<< ", but x.shape[-2] is " << x_m_ << " and other.shape[-1] is " << n_ << ".";
return false;
}
if (n_ < tau_n_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', other.shape[-1] must be greater than or equal to "
<< "tau.shape[-1], but other.shape[-1] is " << n_ << " and tau.shape[-1] is " << tau_n_ << ".";
return false;
}
}
return true;
}
template <typename T>
void OrmqrGpuKernelMod::InitSizeLists() {
// input x, tau, other
input_size_list_.push_back(batch_size_ * x_m_ * x_n_ * sizeof(T));
input_size_list_.push_back(batch_size_ * tau_n_ * sizeof(T));
input_size_list_.push_back(batch_size_ * m_ * n_ * sizeof(T));
// output y
output_size_list_.push_back(batch_size_ * m_ * n_ * sizeof(T));
// dev_info: one cusolver status value per batch
workspace_size_list_.push_back(batch_size_ * sizeof(int));
// transpose descriptors: x shape, x axis permutation, other shape
workspace_size_list_.push_back(input_x_dims_ * sizeof(size_t));
workspace_size_list_.push_back(input_x_dims_ * sizeof(size_t));
workspace_size_list_.push_back(input_other_dims_ * sizeof(size_t));
// transposed copies of x and other, and the transposed result y
workspace_size_list_.push_back(batch_size_ * x_m_ * x_n_ * sizeof(T));
workspace_size_list_.push_back(batch_size_ * m_ * n_ * sizeof(T));
workspace_size_list_.push_back(batch_size_ * m_ * n_ * sizeof(T));
// transpose descriptor for the output y shape
workspace_size_list_.push_back(input_x_dims_ * sizeof(size_t));
}
template <typename T>
void OrmqrGpuKernelMod::RunOrmqr(T *d_a, T *tau, T *d_c, size_t lda, int *dev_info, T *output_y) {
int lwork = 0;
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSormqr_bufferSize(handle_, side_, trans_, m_, n_, x_n_, d_a, lda, tau, d_c, m_, &lwork),
"cusolver query ormqr work size failed.");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDormqr_bufferSize(handle_, side_, trans_, m_, n_, x_n_, d_a, lda, tau, d_c, m_, &lwork),
"cusolver query ormqr work size failed.");
} else {
if constexpr (std::is_same_v<T, Complex<float>>) {
trans_ = transpose_ ? CUBLAS_OP_C : CUBLAS_OP_N;
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnCunmqr_bufferSize(handle_, side_, trans_, m_, n_, x_n_, reinterpret_cast<cuComplex *>(d_a), lda,
reinterpret_cast<cuComplex *>(tau), reinterpret_cast<cuComplex *>(d_c), m_, &lwork),
"cusolver query ormqr work size failed.");
}
if constexpr (std::is_same_v<T, Complex<double>>) {
trans_ = transpose_ ? CUBLAS_OP_C : CUBLAS_OP_N;
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnZunmqr_bufferSize(handle_, side_, trans_, m_, n_, x_n_, reinterpret_cast<cuDoubleComplex *>(d_a), lda,
reinterpret_cast<cuDoubleComplex *>(tau), reinterpret_cast<cuDoubleComplex *>(d_c),
m_, &lwork),
"cusolver query ormqr work size failed.");
}
}
void *d_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(sizeof(T) * lwork);
if (d_work == nullptr) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', failed to allocate device memory for d_work.";
}
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnSormqr(handle_, side_, trans_, m_, n_, x_n_, d_a, lda, tau, d_c,
m_, static_cast<T *>(d_work), lwork, dev_info),
"cusolver ormqr failed.");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnDormqr(handle_, side_, trans_, m_, n_, x_n_, d_a, lda, tau, d_c,
m_, static_cast<T *>(d_work), lwork, dev_info),
"cusolver ormqr failed.");
} else {
if constexpr (std::is_same_v<T, Complex<float>>) {
trans_ = transpose_ ? CUBLAS_OP_C : CUBLAS_OP_N;
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnCunmqr(handle_, side_, trans_, m_, n_, x_n_, reinterpret_cast<cuComplex *>(d_a), lda,
reinterpret_cast<cuComplex *>(tau), reinterpret_cast<cuComplex *>(d_c), m_,
reinterpret_cast<cuComplex *>(d_work), lwork, dev_info),
"cusolver ormqr failed.");
}
if constexpr (std::is_same_v<T, Complex<double>>) {
trans_ = transpose_ ? CUBLAS_OP_C : CUBLAS_OP_N;
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnZunmqr(handle_, side_, trans_, m_, n_, x_n_, reinterpret_cast<cuDoubleComplex *>(d_a), lda,
reinterpret_cast<cuDoubleComplex *>(tau), reinterpret_cast<cuDoubleComplex *>(d_c), m_,
reinterpret_cast<cuDoubleComplex *>(d_work), lwork, dev_info),
"cusolver ormqr failed.");
}
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(output_y, d_c, sizeof(T) * m_ * n_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy output A failed!");
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work);
}
void OrmqrGpuKernelMod::CheckResult(int *dev_info) {
std::vector<int> info_gpu(batch_size_, 0);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(info_gpu.data(), dev_info, sizeof(int) * batch_size_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"Copy device result failed");
// Synchronize before reading the host buffer: the copy above is asynchronous.
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaStreamSynchronize failed");
for (size_t i = 0; i < info_gpu.size(); ++i) {
if (info_gpu[i] != 0) {
MS_LOG(INFO) << "For '" << kernel_name_ << "', the computation failed for batch " << i << ": the "
<< -info_gpu[i] << "-th parameter (not counting the handle) is invalid.";
}
}
}
}
template <typename T>
void OrmqrGpuKernelMod::LaunchOrmqr(T *d_input_x, T *input_tau, T *d_input_other, T *d_output_y, int *dev_info) {
size_t lda = m_;
if (side_ == CUBLAS_SIDE_RIGHT) {
lda = n_;
}
for (size_t batch = 0; batch < batch_size_; ++batch) {
RunOrmqr(d_input_x + batch * x_m_ * x_n_, input_tau + batch * tau_n_, d_input_other + batch * m_ * n_, lda,
dev_info + batch, d_output_y + batch * m_ * n_);
}
CheckResult(dev_info);
}
template <typename T>
bool OrmqrGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (is_null_input_) {
return true;
}
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"CusolverDnSetStream failed");
T *input_x = GetDeviceAddress<T>(inputs, kIndex0);
T *input_tau = GetDeviceAddress<T>(inputs, kIndex1);
T *input_other = GetDeviceAddress<T>(inputs, kIndex2);
T *output_y = GetDeviceAddress<T>(outputs, kIndex0);
int *dev_info = GetDeviceAddress<int>(workspace, kIndex0);
size_t *d_trans_input_x_shape = GetDeviceAddress<size_t>(workspace, kIndex1);
size_t *d_trans_input_x_axis = GetDeviceAddress<size_t>(workspace, kIndex2);
size_t *d_trans_input_other_shape = GetDeviceAddress<size_t>(workspace, kIndex3);
T *d_input_x = GetDeviceAddress<T>(workspace, kIndex4);
T *d_input_other = GetDeviceAddress<T>(workspace, kIndex5);
T *d_output_y = GetDeviceAddress<T>(workspace, kIndex6);
size_t *d_trans_output_y_shape = GetDeviceAddress<size_t>(workspace, kIndex7);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_trans_input_x_axis, transpose_input_x_axis_, sizeof(size_t) * input_x_dims_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_trans_input_x_shape, transpose_input_x_shape_, sizeof(size_t) * input_x_dims_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CalTranspose(batch_size_ * x_m_ * x_n_, input_x, d_trans_input_x_shape, d_trans_input_x_axis, input_x_dims_,
d_input_x, reinterpret_cast<cudaStream_t>(cuda_stream_));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_trans_input_other_shape, transpose_input_other_shape_, sizeof(size_t) * input_other_dims_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CalTranspose(batch_size_ * m_ * n_, input_other, d_trans_input_other_shape, d_trans_input_x_axis, input_other_dims_,
d_input_other, reinterpret_cast<cudaStream_t>(cuda_stream_));
LaunchOrmqr(d_input_x, input_tau, d_input_other, d_output_y, dev_info);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_trans_output_y_shape, transpose_output_y_shape_, sizeof(size_t) * input_other_dims_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcpy failed!");
CalTranspose(batch_size_ * m_ * n_, d_output_y, d_trans_output_y_shape, d_trans_input_x_axis, input_other_dims_,
output_y, reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
std::vector<std::pair<KernelAttr, std::pair<OrmqrGpuKernelMod::LaunchKernelFunc, OrmqrGpuKernelMod::InitSizeListsFunc>>>
OrmqrGpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
{&OrmqrGpuKernelMod::LaunchKernel<float>, &OrmqrGpuKernelMod::InitSizeLists<float>}},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
{&OrmqrGpuKernelMod::LaunchKernel<double>, &OrmqrGpuKernelMod::InitSizeLists<double>}},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
{&OrmqrGpuKernelMod::LaunchKernel<Complex<float>>, &OrmqrGpuKernelMod::InitSizeLists<Complex<float>>}},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
{&OrmqrGpuKernelMod::LaunchKernel<Complex<double>>, &OrmqrGpuKernelMod::InitSizeLists<Complex<double>>}},
};
std::vector<KernelAttr> OrmqrGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Ormqr, OrmqrGpuKernelMod);
} // namespace kernel
} // namespace mindspore
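A note on the CalTranspose calls that bracket the cuSOLVER invocation above: MindSpore tensors are row-major, while cuSOLVER follows LAPACK's column-major convention, and swapping the last two axes of a row-major matrix produces exactly its column-major memory order. A minimal NumPy sketch of that layout identity (illustrative only, not part of the change):

import numpy as np

# Row-major bytes of the transpose coincide with the column-major bytes of
# the original matrix, so transposing before (and back after) a
# Fortran-order library call round-trips the layout.
a = np.arange(6, dtype=np.float32).reshape(2, 3)  # C order (row-major)
assert np.array_equal(a.T.ravel(), a.ravel(order='F'))
print(a.ravel(order='F'))  # [0. 3. 1. 4. 2. 5.]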

mindspore/ccsrc/plugin/device/gpu/kernel/math/ormqr_gpu_kernel.h Normal file

@@ -0,0 +1,117 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_ORMQR_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_ORMQR_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cusolverDn.h>
#include <cuda_runtime.h>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include <utility>
#include "mindspore/core/ops/ormqr.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
class OrmqrGpuKernelMod : public NativeGpuKernelMod {
public:
OrmqrGpuKernelMod() { ResetResource(); }
~OrmqrGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = stream_ptr;
return launch_kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
void ResetResource() noexcept {
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
std::vector<KernelAttr> GetOpSupport() override;
protected:
template <typename T>
void InitSizeLists();
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool CheckInputs();
template <typename T>
void RunOrmqr(T *d_a, T *tau, T *d_c, size_t lda, int *dev_info, T *output_y);
template <typename T>
void LaunchOrmqr(T *d_input_x, T *input_tau, T *d_input_other, T *d_output_y, int *dev_info);
void CheckResult(int *dev_info);
using LaunchKernelFunc =
std::function<bool(OrmqrGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
using InitSizeListsFunc = std::function<void(OrmqrGpuKernelMod *)>;
LaunchKernelFunc launch_kernel_func_{nullptr};
InitSizeListsFunc init_lists_func_{nullptr};
static std::vector<std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>>> func_list_;
bool left_{true};
bool transpose_{false};
std::vector<size_t> input_x_shape_;
std::vector<size_t> input_tau_shape_;
std::vector<size_t> input_other_shape_;
size_t input_x_dims_{0};
size_t input_tau_dims_{0};
size_t input_other_dims_{0};
size_t m_{0};
size_t n_{0};
size_t tau_n_{0};
size_t x_m_{0};
size_t x_n_{0};
size_t batch_size_{0};
bool is_null_input_;
size_t transpose_input_x_shape_[TRANSPOSE_MAX_DIMENSION] = {0};
size_t transpose_input_x_axis_[TRANSPOSE_MAX_DIMENSION] = {0};
size_t transpose_input_other_shape_[TRANSPOSE_MAX_DIMENSION] = {0};
size_t transpose_output_y_shape_[TRANSPOSE_MAX_DIMENSION] = {0};
cusolverDnHandle_t handle_{nullptr};
cublasSideMode_t side_{CUBLAS_SIDE_LEFT};
cublasOperation_t trans_{CUBLAS_OP_N};
void *cuda_stream_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_ORMQR_GPU_KERNEL_H_


@@ -1342,6 +1342,7 @@ GVAR_DEF(PrimitivePtr, kPrimQr, std::make_shared<Primitive>("Qr"));
GVAR_DEF(PrimitivePtr, kPrimMatrixLogarithm, std::make_shared<Primitive>(kMatrixLogarithm));
GVAR_DEF(PrimitivePtr, kPrimMatrixTriangularSolve, std::make_shared<Primitive>(kMatrixTriangularSolve));
GVAR_DEF(PrimitivePtr, kPrimSelfAdjointEig, std::make_shared<Primitive>("SelfAdjointEig"));
GVAR_DEF(PrimitivePtr, kPrimOrmqr, std::make_shared<Primitive>("Ormqr"));
// linalg
GVAR_DEF(PrimitivePtr, kPrimGeqrf, std::make_shared<Primitive>("Geqrf"));

mindspore/core/ops/ormqr.cc Normal file

@@ -0,0 +1,139 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <vector>
#include <memory>
#include <map>
#include <string>
#include "ops/ormqr.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr OrmqrInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
const int64_t kInputNoBatch = 2;
const size_t kRowIndex = 2;
const size_t kColIndex = 1;
const size_t kTwo = 2;
auto left = GetValue<bool>(primitive->GetAttr(kAttrLeft));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto tau_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto other_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
if (IsDynamicRank(x_shape) || IsDynamic(x_shape) || IsDynamicRank(tau_shape) || IsDynamic(tau_shape) ||
IsDynamicRank(other_shape) || IsDynamic(other_shape)) {
return std::make_shared<abstract::Shape>(other_shape);
}
auto x_rank = x_shape.size();
auto tau_rank = tau_shape.size();
auto other_rank = other_shape.size();
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(x_rank), kGreaterEqual, kTwo, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("other_rank", SizeToLong(other_rank), kGreaterEqual, kTwo,
primitive->name());
if ((x_rank - kColIndex) != tau_rank) {
MS_EXCEPTION(ValueError) << "For Ormqr, tau should have one dimension less than x"
<< ", while rank of x is " << x_shape.size() << " and "
<< "rank of tau is " << tau_shape.size() << ".";
}
if (x_rank != other_rank) {
MS_EXCEPTION(ValueError) << "For Ormqr, other should have the same rank as x"
<< ", while rank of x is " << x_shape.size() << " and "
<< "rank of other is " << other_shape.size() << ".";
}
if (x_shape.size() > kInputNoBatch) {
for (size_t i = 0; i < x_rank - kRowIndex; i++) {
if (x_shape[i] != tau_shape[i]) {
MS_EXCEPTION(ValueError) << "For Ormqr, tau.shape[:-1] must be equal to x.shape[:-2], but x.shape[" << i
<< "] is " << x_shape[i] << ", and tau.shape[" << i << "] is " << tau_shape[i] << ".";
}
if (x_shape[i] != other_shape[i]) {
MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[:-2] must be equal to x.shape[:-2], but x.shape[" << i
<< "] is " << x_shape[i] << ", and other.shape[" << i << "] is " << other_shape[i]
<< ".";
}
}
}
if (left) {
if (*(other_shape.end() - kRowIndex) < *(tau_shape.end() - kColIndex)) {
MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[-2] must be greater than or equal to tau.shape[-1]"
<< ", while other.shape[-2] is " << other_shape[other_rank - kRowIndex] << " and "
<< "tau.shape[-1] is " << tau_shape[tau_rank - kColIndex] << ".";
}
if (*(x_shape.end() - kRowIndex) != *(other_shape.end() - kRowIndex)) {
MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[-2] must be equal to x.shape[-2]"
<< ", while x.shape[-2] is " << x_shape[x_rank - kRowIndex] << " and "
<< "other.shape[-2] is " << other_shape[other_rank - kRowIndex] << ".";
}
} else {
if (*(other_shape.end() - kColIndex) < *(tau_shape.end() - kColIndex)) {
MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[-1] must be greater than or equal to tau.shape[-1]"
<< ", while other.shape[-1] is " << other_shape[other_rank - kColIndex] << " and "
<< "tau.shape[-1] is " << tau_shape[tau_rank - kColIndex] << ".";
}
if (*(x_shape.end() - kRowIndex) != *(other_shape.end() - kColIndex)) {
MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[-1] must be equal to x.shape[-2]"
<< ", while x.shape[-2] is " << x_shape[x_rank - kRowIndex] << " and "
<< "other.shape[-1] is " << other_shape[other_rank - kColIndex] << ".";
}
}
return std::make_shared<abstract::Shape>(other_shape);
}
TypePtr OrmqrInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat32, kFloat64, kComplex64, kComplex128};
std::map<std::string, TypePtr> types;
auto x_type = input_args[0]->BuildType();
auto tau_type = input_args[kInputIndex1]->BuildType();
auto other_type = input_args[kInputIndex2]->BuildType();
(void)types.emplace("x", x_type);
(void)types.emplace("tau", tau_type);
(void)types.emplace("other", other_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return x_type;
}
} // namespace
void Ormqr::Init(const bool left, const bool transpose) {
set_left(left);
set_transpose(transpose);
}
void Ormqr::set_left(const bool left) { (void)this->AddAttr(kAttrLeft, api::MakeValue(left)); }
void Ormqr::set_transpose(const bool transpose) { (void)this->AddAttr(kAttrTranspose, api::MakeValue(transpose)); }
bool Ormqr::get_left() const { return GetValue<bool>(GetAttr(kAttrLeft)); }
bool Ormqr::get_transpose() const { return GetValue<bool>(GetAttr(kAttrTranspose)); }
MIND_API_OPERATOR_IMPL(Ormqr, BaseOperator);
AbstractBasePtr OrmqrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = OrmqrInferType(primitive, input_args);
auto infer_shape = OrmqrInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Ormqr, prim::kPrimOrmqr, OrmqrInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
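To make the constraints enforced by OrmqrInferShape easier to scan, here is a hypothetical pure-Python mirror of the same checks (a sketch; the function name and signature are illustrative, not MindSpore API):

def infer_ormqr_shape(x_shape, tau_shape, other_shape, left=True):
    if len(x_shape) < 2 or len(other_shape) < 2:
        raise ValueError("x and other must be at least rank 2")
    if len(x_shape) - 1 != len(tau_shape):
        raise ValueError("tau must have exactly one dimension less than x")
    if len(x_shape) != len(other_shape):
        raise ValueError("other must have the same rank as x")
    if tuple(x_shape[:-2]) != tuple(tau_shape[:-1]) or \
            tuple(x_shape[:-2]) != tuple(other_shape[:-2]):
        raise ValueError("batch dimensions of x, tau and other must match")
    side = -2 if left else -1
    if other_shape[side] != x_shape[-2]:
        raise ValueError("other.shape[%d] must equal x.shape[-2]" % side)
    if other_shape[side] < tau_shape[-1]:
        raise ValueError("other.shape[%d] must be >= tau.shape[-1]" % side)
    return tuple(other_shape)  # y has the same shape as other

# Example: left multiplication with one batch dimension.
assert infer_ormqr_shape((5, 4, 3), (5, 3), (5, 4, 6)) == (5, 4, 6)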

mindspore/core/ops/ormqr.h Normal file

@@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ORMQR_H_
#define MINDSPORE_CORE_OPS_ORMQR_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameOrmqr = "Ormqr";
constexpr auto kAttrLeft = "left";
constexpr auto kAttrTranspose = "transpose";
/// \brief Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
class MIND_API Ormqr : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Ormqr);
/// \brief Constructor.
Ormqr() : BaseOperator(kNameOrmqr) { InitIOName({"x", "tau", "other"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ormqr for the inputs.
void Init(const bool left = true, const bool transpose = false);
/// \brief Set left.
void set_left(const bool left);
/// \brief Set transpose.
void set_transpose(const bool transpose);
/// \brief Get left.
///
/// \return left.
bool get_left() const;
/// \brief Get transpose.
///
/// \return transpose.
bool get_transpose() const;
};
abstract::AbstractBasePtr OrmqrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimOrmqrPtr = std::shared_ptr<Ormqr>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ORMQR_H_

mindspore/python/mindspore/ops/operations/math_ops.py

@@ -7756,3 +7756,61 @@ class Cauchy(Primitive):
validator.check_value_type('size', size, (list), self.name)
for index, size_ in enumerate(size):
validator.check_positive_int(size_, 'size[%d]' % index, self.name)
class Ormqr(Primitive):
r"""
Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
Multiplies an (m, n) matrix C (given by `other`) with a matrix Q, where Q is represented using Householder
reflectors (x, tau), which are typically produced by a QR factorization such as Geqrf.
Args:
left (bool, optional): controls the order of multiplication. If True, computes op(Q)*C.
If False, computes C*op(Q). Default: True.
transpose (bool, optional): controls whether the matrix Q is conjugate transposed or not. Default: False.
Inputs:
- **x** (Tensor) - Tensor of shape (*, mn, k), where mn equals m or n depending on `left`,
with float32, float64, complex64 or complex128 data type.
- **tau** (Tensor) - Tensor of shape (*, min(mn, k)) with the same dtype as `x`.
- **other** (Tensor) - Tensor of shape (*, m, n), where * is zero or more batch dimensions.
Outputs:
- **y** (Tensor) - the output Tensor with the same shape and dtype as `other`.
Raises:
TypeError: If `x`, `tau` or `other` is not a Tensor.
TypeError: If dtype of `x`, `tau` or `other` is not one of: float32, float64, complex64, complex128.
ValueError: If `x` or `other` has rank less than 2.
ValueError: If rank(x) - rank(tau) != 1.
ValueError: If tau.shape[:-1] != x.shape[:-2].
ValueError: If other.shape[:-2] != x.shape[:-2].
ValueError: If `left` is True and other.shape[-2] < tau.shape[-1].
ValueError: If `left` is True and other.shape[-2] != x.shape[-2].
ValueError: If `left` is False and other.shape[-1] < tau.shape[-1].
ValueError: If `left` is False and other.shape[-1] != x.shape[-2].
Supported Platforms:
``GPU``
Examples:
>>> x = Tensor(np.array([[-114.6, 10.9, 1.1], [-0.304, 38.07, 69.38], [-0.45, -0.17, 62]]), mindspore.float32)
>>> tau = Tensor(np.array([1.55, 1.94, 3.0]), mindspore.float32)
>>> other = Tensor(np.array([[-114.6, 10.9, 1.1],
... [-0.304, 38.07, 69.38],
... [-0.45, -0.17, 62]]), mindspore.float32)
>>> net = ops.Ormqr()
>>> y = net(x, tau, other)
>>> print(y)
[[ 63.82713 -13.823125 -116.28614 ]
[ -53.659264 -28.157839 -70.42702 ]
[ -79.54292 24.00183 -41.34253 ]]
"""
@prim_attr_register
def __init__(self, left=True, transpose=False):
"""Initialize Ormqr"""
self.init_prim_io_names(inputs=['x', 'tau', 'other'], outputs=['y'])
self.left = validator.check_value_type('left', left, [bool], self.name)
self.transpose = validator.check_value_type('transpose', transpose, [bool], self.name)
self.add_prim_attr('left', self.left)
self.add_prim_attr('transpose', self.transpose)
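For orientation, a hedged NumPy reference for the operator's semantics, following the LAPACK convention Q = H(1) @ ... @ H(k) with H(i) = I - tau[i] * v_i v_i^T, where v_i is 1 at position i with x[i+1:, i] below it (a sketch under these assumptions: real dtypes, a single matrix, no batching; householder_q and ormqr_reference are illustrative names, not MindSpore API):

import numpy as np

def householder_q(x, tau):
    # Rebuild Q = H(1) @ ... @ H(k) from geqrf-style factors (real case).
    m = x.shape[0]
    q = np.eye(m)
    for i in range(tau.shape[0]):
        v = np.zeros(m)
        v[i] = 1.0
        v[i + 1:] = x[i + 1:, i]
        q = q @ (np.eye(m) - tau[i] * np.outer(v, v))
    return q

def ormqr_reference(x, tau, other, left=True, transpose=False):
    # Apply op(Q) to `other` on the side selected by `left`.
    q = householder_q(x, tau)
    op_q = q.T if transpose else q
    return op_q @ other if left else other @ op_q

Under these assumptions, Ormqr(left, transpose)(x, tau, other) and ormqr_reference(x, tau, other, left, transpose) should agree up to floating-point rounding.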


@@ -0,0 +1,97 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import nn, Tensor
import mindspore.context as context
from mindspore.ops.operations.math_ops import Ormqr
class OrmqrNet(nn.Cell):
def __init__(self, left=True, transpose=False):
super(OrmqrNet, self).__init__()
self.ormqr = Ormqr(left=left, transpose=transpose)
def construct(self, x, tau, other):
return self.ormqr(x, tau, other)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ormqr_rank2_right_double_fp():
"""
Feature: Ormqr operator.
Description: test cases for Ormqr: left=False, transpose=False.
Expectation: the result matches the expected output.
"""
x_np = np.array([[-114.6, 10.9, 1.1],
[-0.304, 38.07, 69.38],
[-0.45, -0.17, 62]]).astype(np.float64)
tau_np = np.array([15.5862, 10.6579]).astype(np.float64)
other_np = np.array([[15.5862, 10.6579, 63.8084],
[0.1885, -10.0553, 4.4496]]).astype(np.float64)
expect = np.array([[270.6946003, 553.6791758, -156.4879992],
[-19.1850094, 64.0901946, 1.5641681]]).astype(np.float64)
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = OrmqrNet(False, False)
output_gr = net(Tensor(x_np), Tensor(tau_np), Tensor(other_np)).asnumpy()
assert np.allclose(expect, output_gr, rtol=1.e-5, atol=1.e-5)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_py = net(Tensor(x_np), Tensor(tau_np), Tensor(other_np)).asnumpy()
assert np.allclose(expect, output_py, rtol=1.e-5, atol=1.e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ormqr_rank3_left_double_fp():
"""
Feature: Ormqr operator.
Description: test cases for Ormqr: left=True, transpose=False.
Expectation: the result matches the expected output.
"""
x_np = np.array([[[1.1090, -1.4204],
[11.4252, -3.1697],
[-0.5425, -0.1447]],
[[7.3681, -0.0566],
[2.8972, 5.1619],
[3.3822, 0.5040]]]).astype(np.float64)
tau_np = np.array([[15.5862, 10.6579], [0.1885, -10.0553]]).astype(np.float64)
other_np = np.array([[[0.8128, 0.6689],
[0.8259, 0.0635],
[-8.0096, -0.1519]],
[[10.6672, 1.0428],
[6.7381, 3.4068],
[0.3646, 6.7011]]]).astype(np.float64)
expect = np.array([[[3566.3712760, 140.9990169],
[40716.8898503, 1602.4521151],
[-1939.2639809, -76.1491614]],
[[-55.6311772, -64.4607712],
[-115.7401958, -118.1534389],
[-188.7906847, -180.4638580]]]).astype(np.float64)
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = OrmqrNet()
output_gr = net(Tensor(x_np), Tensor(tau_np), Tensor(other_np)).asnumpy()
assert np.allclose(expect, output_gr, rtol=1.e-5, atol=1.e-5)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_py = net(Tensor(x_np), Tensor(tau_np), Tensor(other_np)).asnumpy()
assert np.allclose(expect, output_py, rtol=1.e-5, atol=1.e-5)