forked from mindspore-Ecosystem/mindspore
!43129 [assistant][ops] add new gpu operator Ormqr
Merge pull request !43129 from GP/Ormqr
This commit is contained in: commit a84f309598
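
For review convenience, a minimal usage sketch of the operator this PR adds (values are
illustrative; assumes a GPU build of MindSpore that includes this change):

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations.math_ops import Ormqr

context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
# (x, tau) encode Q as Householder reflectors (e.g., from a QR factorization);
# other is the dense matrix to multiply. left=True computes op(Q) @ other.
x = Tensor(np.array([[-114.6, 10.9, 1.1],
                     [-0.304, 38.07, 69.38],
                     [-0.45, -0.17, 62.0]]), mindspore.float32)
tau = Tensor(np.array([1.55, 1.94, 3.0]), mindspore.float32)
other = Tensor(np.ones((3, 3)), mindspore.float32)
y = Ormqr(left=True, transpose=False)(x, tau, other)  # y has the same shape as other
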
@ -0,0 +1,364 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/math/ormqr_gpu_kernel.h"
#include <complex>
#include <vector>
#include <map>
#include <utility>
#include "abstract/utils.h"
#include "kernel/common_utils.h"
#include "include/common/utils/convert_utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/real_to_complex_impl.cuh"

namespace mindspore {
namespace kernel {
bool OrmqrGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                             const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::Ormqr>(base_operator);
  MS_EXCEPTION_IF_NULL(kernel_ptr);
  kernel_name_ = kernel_ptr->name();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in ["
                  << "float32, float64, complex64, complex128], but got: " << kernel_attr << ".";
    return false;
  }
  launch_kernel_func_ = func_list_[index].second.first;
  init_lists_func_ = func_list_[index].second.second;
  left_ = kernel_ptr->get_left();
  transpose_ = kernel_ptr->get_transpose();
  handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
  return true;
}

int OrmqrGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs,
                              const std::map<uint32_t, tensor::TensorPtr> &) {
  ResetResource();
  for (const auto &input : inputs) {
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  input_x_shape_ = std::vector<size_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
                                       inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
  input_tau_shape_ = std::vector<size_t>(inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(),
                                         inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
  input_other_shape_ = std::vector<size_t>(inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(),
                                           inputs.at(kIndex2)->GetDeviceShapeAdaptively().end());
  input_x_dims_ = input_x_shape_.size();
  input_tau_dims_ = input_tau_shape_.size();
  input_other_dims_ = input_other_shape_.size();
  is_null_input_ = (input_x_dims_ == 0 || input_tau_dims_ == 0 || input_other_dims_ == 0);
  if (is_null_input_) {
    init_lists_func_(this);
    return 0;
  }
  batch_size_ = 1;
  for (size_t i = 0; i < input_x_dims_ - kDim2; i++) {
    if (input_x_shape_[i] != input_tau_shape_[i]) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', x and tau should share the same batch size, but x.shape[" << i
                    << "] is " << input_x_shape_[i] << ", and tau.shape[" << i << "] is " << input_tau_shape_[i];
      return KRET_RESIZE_FAILED;
    }
    batch_size_ = batch_size_ * input_x_shape_[i];
  }
  for (size_t i = 0; i < input_x_dims_ - kDim2; i++) {
    if (input_x_shape_[i] != input_other_shape_[i]) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', x and other should share the same batch size, but x.shape[" << i
                    << "] is " << input_x_shape_[i] << ", and other.shape[" << i << "] is " << input_other_shape_[i];
      return KRET_RESIZE_FAILED;
    }
  }
  side_ = left_ ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
  trans_ = transpose_ ? CUBLAS_OP_T : CUBLAS_OP_N;
  m_ = input_other_shape_[input_other_dims_ - kDim2], n_ = input_other_shape_[input_other_dims_ - kDim1];
  x_m_ = input_x_shape_[input_x_dims_ - kDim2], x_n_ = input_x_shape_[input_x_dims_ - kDim1];
  tau_n_ = input_tau_shape_[input_tau_dims_ - kDim1];
  bool check_inputs = CheckInputs();
  if (!check_inputs) {
    return KRET_RESIZE_FAILED;
  }
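
  // cuSOLVER works on column-major matrices while MindSpore tensors are row-major,
  // so precompute transpose parameters that swap the last two axes of x and other
  // on the way in (and of y on the way back out in LaunchKernel).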
  for (size_t i = 0; i < input_x_dims_; ++i) {
    transpose_input_x_shape_[i] = input_x_shape_[i];
    transpose_input_other_shape_[i] = input_other_shape_[i];
    if (i == input_x_dims_ - kDim2) {
      transpose_input_x_axis_[i] = input_x_dims_ - kDim1;
      transpose_output_y_shape_[i] = input_other_shape_[input_other_dims_ - kDim1];
    } else if (i == input_x_dims_ - kDim1) {
      transpose_input_x_axis_[i] = input_x_dims_ - kDim2;
      transpose_output_y_shape_[i] = input_other_shape_[input_other_dims_ - kDim2];
    } else {
      transpose_input_x_axis_[i] = i;
      transpose_output_y_shape_[i] = input_other_shape_[i];
    }
  }
  init_lists_func_(this);
  return 0;
}

bool OrmqrGpuKernelMod::CheckInputs() {
  if (input_x_dims_ < kDim2) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', dimensions of x must be greater than or equal to 2"
                  << ", but got [" << input_x_dims_ << "].";
    return false;
  }
  if (input_other_dims_ < kDim2) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', dimensions of other must be greater than or equal to 2"
                  << ", but got [" << input_other_dims_ << "].";
    return false;
  }
  if (input_x_dims_ - kDim1 != input_tau_dims_) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', tau should have one dimension less than x"
                  << ", but rank of x is " << input_x_dims_ << " and rank of tau is " << input_tau_dims_ << ".";
    return false;
  }
  if (input_x_dims_ != input_other_dims_) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', other should have the same rank as x"
                  << ", but rank of x is " << input_x_dims_ << " and rank of other is " << input_other_dims_ << ".";
    return false;
  }
  if (left_) {
    if (m_ != x_m_) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', other.shape[-2] must be equal to x.shape[-2]"
                    << ", but x.shape[-2] is " << x_m_ << " and other.shape[-2] is " << m_ << ".";
      return false;
    }
    if (m_ < tau_n_) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', other.shape[-2] must be greater than or equal to "
                    << "tau.shape[-1], but other.shape[-2] is " << m_ << " and tau.shape[-1] is " << tau_n_;
      return false;
    }
  } else {
    if (n_ != x_m_) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', other.shape[-1] must be equal to x.shape[-2]"
                    << ", but x.shape[-2] is " << x_m_ << " and other.shape[-1] is " << n_ << ".";
      return false;
    }
    if (n_ < tau_n_) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', other.shape[-1] must be greater than or equal to "
                    << "tau.shape[-1], but other.shape[-1] is " << n_ << " and tau.shape[-1] is " << tau_n_ << ".";
      return false;
    }
  }
  return true;
}

template <typename T>
void OrmqrGpuKernelMod::InitSizeLists() {
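  // Workspace layout; must stay in sync with the GetDeviceAddress indices in
  // LaunchKernel: [0] dev_info, [1] x transpose shape, [2] transpose axes,
  // [3] other transpose shape, [4] transposed x, [5] transposed other,
  // [6] pre-transpose y, [7] y transpose shape.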
  // input x, tau, other
  input_size_list_.push_back(batch_size_ * x_m_ * x_n_ * sizeof(T));
  input_size_list_.push_back(batch_size_ * tau_n_ * sizeof(T));
  input_size_list_.push_back(batch_size_ * m_ * n_ * sizeof(T));
  // output y
  output_size_list_.push_back(batch_size_ * m_ * n_ * sizeof(T));
  workspace_size_list_.push_back(batch_size_ * sizeof(int));
  // for transpose input x and output y
  workspace_size_list_.push_back(input_x_dims_ * sizeof(size_t));
  workspace_size_list_.push_back(input_x_dims_ * sizeof(size_t));
  workspace_size_list_.push_back(input_other_dims_ * sizeof(size_t));
  workspace_size_list_.push_back(batch_size_ * x_m_ * x_n_ * sizeof(T));
  workspace_size_list_.push_back(batch_size_ * m_ * n_ * sizeof(T));
  workspace_size_list_.push_back(batch_size_ * m_ * n_ * sizeof(T));
  workspace_size_list_.push_back(input_x_dims_ * sizeof(size_t));
}

template <typename T>
void OrmqrGpuKernelMod::RunOrmqr(T *d_a, T *tau, T *d_c, size_t lda, int *dev_info, T *output_y) {
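  // Query the required workspace size first, then dispatch to the dtype-specific
  // routine: Sormqr/Dormqr for real types, Cunmqr/Zunmqr for complex types (where
  // "transpose" means conjugate transpose, CUBLAS_OP_C).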
  int lwork = 0;
  if constexpr (std::is_same_v<T, float>) {
    CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
      cusolverDnSormqr_bufferSize(handle_, side_, trans_, m_, n_, x_n_, d_a, lda, tau, d_c, m_, &lwork),
      "cusolver query ormqr work size failed.");
  } else if constexpr (std::is_same_v<T, double>) {
    CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
      cusolverDnDormqr_bufferSize(handle_, side_, trans_, m_, n_, x_n_, d_a, lda, tau, d_c, m_, &lwork),
      "cusolver query ormqr work size failed.");
  } else {
    if constexpr (std::is_same_v<T, Complex<float>>) {
      trans_ = transpose_ ? CUBLAS_OP_C : CUBLAS_OP_N;
      CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
        cusolverDnCunmqr_bufferSize(handle_, side_, trans_, m_, n_, x_n_, reinterpret_cast<cuComplex *>(d_a), lda,
                                    reinterpret_cast<cuComplex *>(tau), reinterpret_cast<cuComplex *>(d_c), m_,
                                    &lwork),
        "cusolver query ormqr work size failed.");
    }
    if constexpr (std::is_same_v<T, Complex<double>>) {
      trans_ = transpose_ ? CUBLAS_OP_C : CUBLAS_OP_N;
      CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
        cusolverDnZunmqr_bufferSize(handle_, side_, trans_, m_, n_, x_n_, reinterpret_cast<cuDoubleComplex *>(d_a),
                                    lda, reinterpret_cast<cuDoubleComplex *>(tau),
                                    reinterpret_cast<cuDoubleComplex *>(d_c), m_, &lwork),
        "cusolver query ormqr work size failed.");
    }
  }

  void *d_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(sizeof(T) * lwork);
  if (d_work == nullptr) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', failed to allocate memory for d_work.";
  }
  if constexpr (std::is_same_v<T, float>) {
    CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnSormqr(handle_, side_, trans_, m_, n_, x_n_, d_a, lda, tau, d_c,
                                                            m_, static_cast<T *>(d_work), lwork, dev_info),
                                           "cusolver ormqr failed.");
  } else if constexpr (std::is_same_v<T, double>) {
    CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnDormqr(handle_, side_, trans_, m_, n_, x_n_, d_a, lda, tau, d_c,
                                                            m_, static_cast<T *>(d_work), lwork, dev_info),
                                           "cusolver ormqr failed.");
  } else {
    if constexpr (std::is_same_v<T, Complex<float>>) {
      trans_ = transpose_ ? CUBLAS_OP_C : CUBLAS_OP_N;
      CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
        cusolverDnCunmqr(handle_, side_, trans_, m_, n_, x_n_, reinterpret_cast<cuComplex *>(d_a), lda,
                         reinterpret_cast<cuComplex *>(tau), reinterpret_cast<cuComplex *>(d_c), m_,
                         reinterpret_cast<cuComplex *>(d_work), lwork, dev_info),
        "cusolver ormqr failed.");
    }
    if constexpr (std::is_same_v<T, Complex<double>>) {
      trans_ = transpose_ ? CUBLAS_OP_C : CUBLAS_OP_N;
      CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
        cusolverDnZunmqr(handle_, side_, trans_, m_, n_, x_n_, reinterpret_cast<cuDoubleComplex *>(d_a), lda,
                         reinterpret_cast<cuDoubleComplex *>(tau), reinterpret_cast<cuDoubleComplex *>(d_c), m_,
                         reinterpret_cast<cuDoubleComplex *>(d_work), lwork, dev_info),
        "cusolver ormqr failed.");
    }
  }

  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(output_y, d_c, sizeof(T) * m_ * n_, cudaMemcpyDeviceToDevice,
                                                     reinterpret_cast<cudaStream_t>(cuda_stream_)),
                                     "cuda memcpy output y failed!");
  device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work);
}

void OrmqrGpuKernelMod::CheckResult(int *dev_info) {
  std::vector<int> info_gpu(batch_size_, 0);
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(info_gpu.data(), dev_info, sizeof(int) * batch_size_, cudaMemcpyDeviceToHost,
                    reinterpret_cast<cudaStream_t>(cuda_stream_)),
    "Copy device result failed");
  for (size_t i = 0; i < info_gpu.size(); ++i) {
    if (info_gpu[i] != 0) {
      MS_LOG(INFO) << "For '" << kernel_name_ << "', the compute result is invalid: the " << -info_gpu[i]
                   << "-th parameter (not counting the handle) was wrong for batch " << i << ".";
    }
  }
}

template <typename T>
void OrmqrGpuKernelMod::LaunchOrmqr(T *d_input_x, T *input_tau, T *d_input_other, T *d_output_y, int *dev_info) {
  size_t lda = m_;
  if (side_ == CUBLAS_SIDE_RIGHT) {
    lda = n_;
  }
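  // cuSOLVER ormqr handles one matrix at a time, so step through the batch by
  // offsetting the device pointers to each (x, tau, other, y) slice.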
  for (size_t batch = 0; batch < batch_size_; ++batch) {
    RunOrmqr(d_input_x + batch * x_m_ * x_n_, input_tau + batch * tau_n_, d_input_other + batch * m_ * n_, lda,
             dev_info + batch, d_output_y + batch * m_ * n_);
  }

  CheckResult(dev_info);
}

template <typename T>
bool OrmqrGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                     const std::vector<AddressPtr> &outputs) {
  if (is_null_input_) {
    return true;
  }
  CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(cuda_stream_)),
                                "CusolverDnSetStream failed");

  T *input_x = GetDeviceAddress<T>(inputs, kIndex0);
  T *input_tau = GetDeviceAddress<T>(inputs, kIndex1);
  T *input_other = GetDeviceAddress<T>(inputs, kIndex2);
  T *output_y = GetDeviceAddress<T>(outputs, kIndex0);

  int *dev_info = GetDeviceAddress<int>(workspace, kIndex0);
  size_t *d_trans_input_x_shape = GetDeviceAddress<size_t>(workspace, kIndex1);
  size_t *d_trans_input_x_axis = GetDeviceAddress<size_t>(workspace, kIndex2);
  size_t *d_trans_input_other_shape = GetDeviceAddress<size_t>(workspace, kIndex3);
  T *d_input_x = GetDeviceAddress<T>(workspace, kIndex4);
  T *d_input_other = GetDeviceAddress<T>(workspace, kIndex5);
  T *d_output_y = GetDeviceAddress<T>(workspace, kIndex6);
  size_t *d_trans_output_y_shape = GetDeviceAddress<size_t>(workspace, kIndex7);
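
  // Pipeline: copy the transpose metadata to the device, transpose x and other to
  // column-major layout, run ormqr batch by batch, then transpose y back to row-major.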
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(d_trans_input_x_axis, transpose_input_x_axis_, sizeof(size_t) * input_x_dims_,
                    cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
    "cuda memcpy failed!");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(d_trans_input_x_shape, transpose_input_x_shape_, sizeof(size_t) * input_x_dims_,
                    cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
    "cuda memcpy failed!");
  CalTranspose(batch_size_ * x_m_ * x_n_, input_x, d_trans_input_x_shape, d_trans_input_x_axis, input_x_dims_,
               d_input_x, reinterpret_cast<cudaStream_t>(cuda_stream_));

  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(d_trans_input_other_shape, transpose_input_other_shape_, sizeof(size_t) * input_other_dims_,
                    cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
    "cuda memcpy failed!");
  CalTranspose(batch_size_ * m_ * n_, input_other, d_trans_input_other_shape, d_trans_input_x_axis, input_other_dims_,
               d_input_other, reinterpret_cast<cudaStream_t>(cuda_stream_));
  LaunchOrmqr(d_input_x, input_tau, d_input_other, d_output_y, dev_info);
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(d_trans_output_y_shape, transpose_output_y_shape_, sizeof(size_t) * input_other_dims_,
                    cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
    "cuda memcpy failed!");
  CalTranspose(batch_size_ * m_ * n_, d_output_y, d_trans_output_y_shape, d_trans_input_x_axis, input_other_dims_,
               output_y, reinterpret_cast<cudaStream_t>(cuda_stream_));

  return true;
}

std::vector<std::pair<KernelAttr, std::pair<OrmqrGpuKernelMod::LaunchKernelFunc, OrmqrGpuKernelMod::InitSizeListsFunc>>>
  OrmqrGpuKernelMod::func_list_ = {
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddOutputAttr(kNumberTypeFloat32),
     {&OrmqrGpuKernelMod::LaunchKernel<float>, &OrmqrGpuKernelMod::InitSizeLists<float>}},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeFloat64)
       .AddOutputAttr(kNumberTypeFloat64),
     {&OrmqrGpuKernelMod::LaunchKernel<double>, &OrmqrGpuKernelMod::InitSizeLists<double>}},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex64)
       .AddInputAttr(kNumberTypeComplex64)
       .AddInputAttr(kNumberTypeComplex64)
       .AddOutputAttr(kNumberTypeComplex64),
     {&OrmqrGpuKernelMod::LaunchKernel<Complex<float>>, &OrmqrGpuKernelMod::InitSizeLists<Complex<float>>}},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeComplex128)
       .AddOutputAttr(kNumberTypeComplex128),
     {&OrmqrGpuKernelMod::LaunchKernel<Complex<double>>, &OrmqrGpuKernelMod::InitSizeLists<Complex<double>>}},
};

std::vector<KernelAttr> OrmqrGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(
    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Ormqr, OrmqrGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,117 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_ORMQR_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_ORMQR_GPU_KERNEL_H_

#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cusolverDn.h>
#include <cuda_runtime.h>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include <utility>
#include "mindspore/core/ops/ormqr.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
class OrmqrGpuKernelMod : public NativeGpuKernelMod {
 public:
  OrmqrGpuKernelMod() { ResetResource(); }
  ~OrmqrGpuKernelMod() = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    cuda_stream_ = stream_ptr;
    return launch_kernel_func_(this, inputs, workspace, outputs);
  }
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  void ResetResource() noexcept {
    is_null_input_ = false;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

  std::vector<KernelAttr> GetOpSupport() override;

 protected:
  template <typename T>
  void InitSizeLists();

  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  bool CheckInputs();
  template <typename T>
  void RunOrmqr(T *d_a, T *tau, T *d_c, size_t lda, int *dev_info, T *output_y);
  template <typename T>
  void LaunchOrmqr(T *d_input_x, T *input_tau, T *d_input_other, T *d_output_y, int *dev_info);
  void CheckResult(int *dev_info);

  using LaunchKernelFunc =
    std::function<bool(OrmqrGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  using InitSizeListsFunc = std::function<void(OrmqrGpuKernelMod *)>;
  LaunchKernelFunc launch_kernel_func_{nullptr};
  InitSizeListsFunc init_lists_func_{nullptr};
  static std::vector<std::pair<KernelAttr, std::pair<LaunchKernelFunc, InitSizeListsFunc>>> func_list_;

  bool left_{true};
  bool transpose_{false};
  std::vector<size_t> input_x_shape_;
  std::vector<size_t> input_tau_shape_;
  std::vector<size_t> input_other_shape_;
  size_t input_x_dims_{0};
  size_t input_tau_dims_{0};
  size_t input_other_dims_{0};
  size_t m_{0};
  size_t n_{0};
  size_t tau_n_{0};
  size_t x_m_{0};
  size_t x_n_{0};
  size_t batch_size_{0};
  bool is_null_input_{false};
  size_t transpose_input_x_shape_[TRANSPOSE_MAX_DIMENSION] = {0};
  size_t transpose_input_x_axis_[TRANSPOSE_MAX_DIMENSION] = {0};
  size_t transpose_input_other_shape_[TRANSPOSE_MAX_DIMENSION] = {0};
  size_t transpose_output_y_shape_[TRANSPOSE_MAX_DIMENSION] = {0};
  cusolverDnHandle_t handle_{nullptr};
  cublasSideMode_t side_{CUBLAS_SIDE_LEFT};
  cublasOperation_t trans_{CUBLAS_OP_N};
  void *cuda_stream_{nullptr};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_ORMQR_GPU_KERNEL_H_
@ -1342,6 +1342,7 @@ GVAR_DEF(PrimitivePtr, kPrimQr, std::make_shared<Primitive>("Qr"));
GVAR_DEF(PrimitivePtr, kPrimMatrixLogarithm, std::make_shared<Primitive>(kMatrixLogarithm));
GVAR_DEF(PrimitivePtr, kPrimMatrixTriangularSolve, std::make_shared<Primitive>(kMatrixTriangularSolve));
GVAR_DEF(PrimitivePtr, kPrimSelfAdjointEig, std::make_shared<Primitive>("SelfAdjointEig"));
GVAR_DEF(PrimitivePtr, kPrimOrmqr, std::make_shared<Primitive>("Ormqr"));

// linalg
GVAR_DEF(PrimitivePtr, kPrimGeqrf, std::make_shared<Primitive>("Geqrf"));
@ -0,0 +1,139 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <set>
#include <vector>
#include <memory>
#include <map>
#include <string>

#include "ops/ormqr.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr OrmqrInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  const int64_t kInputNoBatch = 2;
  const size_t kRowIndex = 2;
  const size_t kColIndex = 1;
  const size_t kTwo = 2;
  auto left = GetValue<bool>(primitive->GetAttr(kAttrLeft));
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto tau_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto other_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  if (IsDynamicRank(x_shape) || IsDynamic(x_shape) || IsDynamicRank(tau_shape) || IsDynamic(tau_shape) ||
      IsDynamicRank(other_shape) || IsDynamic(other_shape)) {
    return std::make_shared<abstract::Shape>(other_shape);
  }
  auto x_rank = x_shape.size();
  auto tau_rank = tau_shape.size();
  auto other_rank = other_shape.size();
  (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(x_rank), kGreaterEqual, kTwo, primitive->name());
  (void)CheckAndConvertUtils::CheckInteger("other_rank", SizeToLong(other_rank), kGreaterEqual, kTwo,
                                           primitive->name());

  if ((x_rank - kColIndex) != tau_rank) {
    MS_EXCEPTION(ValueError) << "For Ormqr, tau should have one dimension less than x"
                             << ", while rank of x is " << x_shape.size() << " and "
                             << "rank of tau is " << tau_shape.size() << ".";
  }
  if (x_rank != other_rank) {
    MS_EXCEPTION(ValueError) << "For Ormqr, other should have the same rank as x"
                             << ", while rank of x is " << x_shape.size() << " and "
                             << "rank of other is " << other_shape.size() << ".";
  }
  if (x_shape.size() > kInputNoBatch) {
    for (size_t i = 0; i < x_rank - kRowIndex; i++) {
      if (x_shape[i] != tau_shape[i]) {
        MS_EXCEPTION(ValueError) << "For Ormqr, tau.shape[:-2] must be equal to x.shape[:-2], but x.shape[" << i
                                 << "] is " << x_shape[i] << ", and tau.shape[" << i << "] is " << tau_shape[i] << ".";
      }
      if (x_shape[i] != other_shape[i]) {
        MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[:-2] must be equal to x.shape[:-2], but x.shape[" << i
                                 << "] is " << x_shape[i] << ", and other.shape[" << i << "] is " << other_shape[i]
                                 << ".";
      }
    }
  }
  if (left) {
    if (*(other_shape.end() - kRowIndex) < *(tau_shape.end() - kColIndex)) {
      MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[-2] must be greater than or equal to tau.shape[-1]"
                               << ", while other.shape[-2] is " << other_shape[other_rank - kRowIndex] << " and "
                               << "tau.shape[-1] is " << tau_shape[tau_rank - kColIndex] << ".";
    }
    if (*(x_shape.end() - kRowIndex) != *(other_shape.end() - kRowIndex)) {
      MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[-2] must be equal to x.shape[-2]"
                               << ", while x.shape[-2] is " << x_shape[x_rank - kRowIndex] << " and "
                               << "other.shape[-2] is " << other_shape[other_rank - kRowIndex] << ".";
    }
  } else {
    if (*(other_shape.end() - kColIndex) < *(tau_shape.end() - kColIndex)) {
      MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[-1] must be greater than or equal to tau.shape[-1]"
                               << ", while other.shape[-1] is " << other_shape[other_rank - kColIndex] << " and "
                               << "tau.shape[-1] is " << tau_shape[tau_rank - kColIndex] << ".";
    }
    if (*(x_shape.end() - kRowIndex) != *(other_shape.end() - kColIndex)) {
      MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[-1] must be equal to x.shape[-2]"
                               << ", while x.shape[-2] is " << x_shape[x_rank - kRowIndex] << " and "
                               << "other.shape[-1] is " << other_shape[other_rank - kColIndex] << ".";
    }
  }

  return std::make_shared<abstract::Shape>(other_shape);
}

TypePtr OrmqrInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64, kComplex64, kComplex128};
  std::map<std::string, TypePtr> types;
  auto x_type = input_args[0]->BuildType();
  auto tau_type = input_args[kInputIndex1]->BuildType();
  auto other_type = input_args[kInputIndex2]->BuildType();
  (void)types.emplace("x", x_type);
  (void)types.emplace("tau", tau_type);
  (void)types.emplace("other", other_type);
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
  return x_type;
}
}  // namespace

void Ormqr::Init(const bool left, const bool transpose) {
  set_left(left);
  set_transpose(transpose);
}

void Ormqr::set_left(const bool left) { (void)this->AddAttr(kAttrLeft, api::MakeValue(left)); }

void Ormqr::set_transpose(const bool transpose) { (void)this->AddAttr(kAttrTranspose, api::MakeValue(transpose)); }

bool Ormqr::get_left() const { return GetValue<bool>(GetAttr(kAttrLeft)); }
bool Ormqr::get_transpose() const { return GetValue<bool>(GetAttr(kAttrTranspose)); }

MIND_API_OPERATOR_IMPL(Ormqr, BaseOperator);
AbstractBasePtr OrmqrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 3;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto infer_type = OrmqrInferType(primitive, input_args);
  auto infer_shape = OrmqrInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Ormqr, prim::kPrimOrmqr, OrmqrInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@ -0,0 +1,57 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_ORMQR_H_
#define MINDSPORE_CORE_OPS_ORMQR_H_
#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameOrmqr = "Ormqr";
constexpr auto kAttrLeft = "left";
constexpr auto kAttrTranspose = "transpose";
/// \brief Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
class MIND_API Ormqr : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Ormqr);
  /// \brief Constructor.
  Ormqr() : BaseOperator(kNameOrmqr) { InitIOName({"x", "tau", "other"}, {"y"}); }
  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ormqr for the inputs.
  void Init(const bool left = true, const bool transpose = false);
  /// \brief Set left.
  void set_left(const bool left);
  /// \brief Set transpose.
  void set_transpose(const bool transpose);
  /// \brief Get left.
  ///
  /// \return left.
  bool get_left() const;
  /// \brief Get transpose.
  ///
  /// \return transpose.
  bool get_transpose() const;
};

abstract::AbstractBasePtr OrmqrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimOrmqrPtr = std::shared_ptr<Ormqr>;
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_ORMQR_H_
@ -7756,3 +7756,61 @@ class Cauchy(Primitive):
        validator.check_value_type('size', size, (list), self.name)
        for index, size_ in enumerate(size):
            validator.check_positive_int(size_, 'size[%d]' % index, self.name)


class Ormqr(Primitive):
    r"""
    Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
    Multiplies an (m, n) matrix C (given by `other`) with a matrix Q, where Q is represented using Householder
    reflectors (x, tau), as produced by a QR factorization (e.g., the Geqrf operator).

    Args:
        left (bool, optional): Controls the order of multiplication. If True, computes op(Q)*C.
            If False, computes C*op(Q). Default: True.
        transpose (bool, optional): Controls whether the matrix Q is conjugate-transposed. Default: False.

    Inputs:
        - **x** (Tensor) - Tensor of shape (*, mn, k), where mn equals m or n depending on `left`,
          with float32, float64, complex64 or complex128 data type.
        - **tau** (Tensor) - Tensor of shape (*, min(mn, k)), which has the same dtype as `x`.
        - **other** (Tensor) - Tensor of shape (*, m, n), where * is zero or more batch dimensions.

    Outputs:
        - **y** (Tensor) - The output Tensor, with the same shape as `other`.

    Raises:
        TypeError: If `x` or `tau` or `other` is not a Tensor.
        TypeError: If dtype of `x` or `tau` or `other` is not one of: float64, float32, complex64, complex128.
        ValueError: If `x` or `other` is less than 2D.
        ValueError: If rank(x) - rank(tau) != 1.
        ValueError: If tau.shape[:-2] != x.shape[:-2].
        ValueError: If other.shape[:-2] != x.shape[:-2].
        ValueError: If `left` is True and other.shape[-2] < tau.shape[-1].
        ValueError: If `left` is True and other.shape[-2] != x.shape[-2].
        ValueError: If `left` is False and other.shape[-1] < tau.shape[-1].
        ValueError: If `left` is False and other.shape[-1] != x.shape[-2].

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.array([[-114.6, 10.9, 1.1], [-0.304, 38.07, 69.38], [-0.45, -0.17, 62]]), mindspore.float32)
        >>> tau = Tensor(np.array([1.55, 1.94, 3.0]), mindspore.float32)
        >>> other = Tensor(np.array([[-114.6, 10.9, 1.1],
        ...                          [-0.304, 38.07, 69.38],
        ...                          [-0.45, -0.17, 62]]), mindspore.float32)
        >>> net = ops.Ormqr()
        >>> y = net(x, tau, other)
        >>> print(y)
        [[  63.82713   -13.823125 -116.28614 ]
         [ -53.659264  -28.157839  -70.42702 ]
         [ -79.54292    24.00183   -41.34253 ]]
    """
    @prim_attr_register
    def __init__(self, left=True, transpose=False):
        """Initialize Ormqr"""
        self.init_prim_io_names(inputs=['x', 'tau', 'other'], outputs=['y'])
        self.left = validator.check_value_type('left', left, [bool], self.name)
        self.transpose = validator.check_value_type('transpose', transpose, [bool], self.name)
        self.add_prim_attr('left', self.left)
        self.add_prim_attr('transpose', self.transpose)
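
As a cross-check of the semantics above, Q can be reconstructed explicitly from the
Householder reflectors. A minimal NumPy sketch (the helper name and the values below
are illustrative only, not part of the operator's API):

import numpy as np

def q_from_householder(x, tau):
    """Rebuild Q = H_1 @ H_2 @ ... @ H_k with H_i = I - tau[i] * outer(v_i, v_i),
    where reflector v_i is stored below the diagonal of column i of x."""
    m, k = x.shape[0], tau.shape[0]
    q = np.eye(m)
    for i in range(k):
        v = np.zeros(m)
        v[i] = 1.0
        v[i + 1:] = x[i + 1:, i]
        q = q @ (np.eye(m) - tau[i] * np.outer(v, v))
    return q

x = np.array([[-114.6, 10.9, 1.1], [-0.304, 38.07, 69.38], [-0.45, -0.17, 62.0]])
tau = np.array([1.55, 1.94, 3.0])
other = np.eye(3)
# Ormqr(left=True, transpose=False)(x, tau, other) should match this product
# (up to float tolerance).
print(q_from_householder(x, tau) @ other)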
@ -0,0 +1,97 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import nn, Tensor
import mindspore.context as context
from mindspore.ops.operations.math_ops import Ormqr


class OrmqrNet(nn.Cell):
    def __init__(self, left=True, transpose=False):
        super(OrmqrNet, self).__init__()
        self.ormqr = Ormqr(left=left, transpose=transpose)

    def construct(self, x, tau, other):
        return self.ormqr(x, tau, other)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ormqr_rank2_right_double_fp():
    """
    Feature: Ormqr operator.
    Description: test cases for Ormqr: left=False, transpose=False.
    Expectation: the result matches the expectation.
    """
    x_np = np.array([[-114.6, 10.9, 1.1],
                     [-0.304, 38.07, 69.38],
                     [-0.45, -0.17, 62]]).astype(np.float64)
    tau_np = np.array([15.5862, 10.6579]).astype(np.float64)
    other_np = np.array([[15.5862, 10.6579, 63.8084],
                         [0.1885, -10.0553, 4.4496]]).astype(np.float64)
    expect = np.array([[270.6946003, 553.6791758, -156.4879992],
                       [-19.1850094, 64.0901946, 1.5641681]]).astype(np.float64)
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = OrmqrNet(False, False)
    output_gr = net(Tensor(x_np), Tensor(tau_np), Tensor(other_np)).asnumpy()
    assert np.allclose(expect, output_gr, rtol=1.e-5, atol=1.e-5)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    output_py = net(Tensor(x_np), Tensor(tau_np), Tensor(other_np)).asnumpy()
    assert np.allclose(expect, output_py, rtol=1.e-5, atol=1.e-5)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ormqr_rank3_left_double_fp():
    """
    Feature: Ormqr operator.
    Description: test cases for Ormqr: left=True, transpose=False.
    Expectation: the result matches the expectation.
    """
    x_np = np.array([[[1.1090, -1.4204],
                      [11.4252, -3.1697],
                      [-0.5425, -0.1447]],

                     [[7.3681, -0.0566],
                      [2.8972, 5.1619],
                      [3.3822, 0.5040]]]).astype(np.float64)
    tau_np = np.array([[15.5862, 10.6579], [0.1885, -10.0553]]).astype(np.float64)
    other_np = np.array([[[0.8128, 0.6689],
                          [0.8259, 0.0635],
                          [-8.0096, -0.1519]],

                         [[10.6672, 1.0428],
                          [6.7381, 3.4068],
                          [0.3646, 6.7011]]]).astype(np.float64)
    expect = np.array([[[3566.3712760, 140.9990169],
                        [40716.8898503, 1602.4521151],
                        [-1939.2639809, -76.1491614]],

                       [[-55.6311772, -64.4607712],
                        [-115.7401958, -118.1534389],
                        [-188.7906847, -180.4638580]]]).astype(np.float64)
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    net = OrmqrNet()
    output_gr = net(Tensor(x_np), Tensor(tau_np), Tensor(other_np)).asnumpy()
    assert np.allclose(expect, output_gr, rtol=1.e-5, atol=1.e-5)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    output_py = net(Tensor(x_np), Tensor(tau_np), Tensor(other_np)).asnumpy()
    assert np.allclose(expect, output_py, rtol=1.e-5, atol=1.e-5)