!40215 [assistant][ops] Add new gpu operator Lu

Merge pull request !40215 from LiuMingwu/Lu
i-robot 2022-11-29 01:37:41 +00:00 committed by Gitee
commit ad88c0e1f6
11 changed files with 913 additions and 345 deletions
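
For reference, the new operator computes an LU factorization with partial pivoting on GPU and returns the packed LU factors together with a row-permutation vector p. A minimal usage sketch, mirroring the Lu docstring example added later in this PR (the matrix values below are taken from that example):

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor, ops

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = Tensor(np.array([[2.5, 3.1, 3.5],
                     [4.7, 1.9, 0.2],
                     [1.1, 3.6, 2.0]]), mindspore.float32)
# output_idx_type selects the dtype of the returned permutation indices (int32 or int64).
lu, p = ops.Lu(output_idx_type=mindspore.int32)(a)
print(lu)  # packed factors: strictly lower part is L (unit diagonal implied), upper part is U
print(p)   # row permutation indices, [1 2 0] for this input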

View File

@@ -15,22 +15,315 @@
*/
#include "plugin/device/gpu/kernel/math/lu_gpu_kernel.h"
#include <iostream>
#include <functional>
#include <utility>
#include <string>
#include <algorithm>
#include "abstract/utils.h"
#include "kernel/common_utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_transpose_impl.cuh"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUGpuKernelMod, float)
bool LuGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
return false;
}
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
cublas_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
return true;
}
MS_REG_GPU_KERNEL_ONE(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUGpuKernelMod, double)
int LuGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
unit_size_ = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
auto in_shape = inputs.at(kIndex0)->GetShapeVector();
(void)std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(in_shape_), LongToSize);
if (!CheckLuShape()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input shape init failed.";
return KRET_RESIZE_FAILED;
}
// a device addr to place lu factor return code
workspace_size_list_.push_back(sizeof(int));
// transpose workspace
workspace_size_list_.push_back(batch_size_ * m_ * n_ * unit_size_);
// pivot indices for each matrix in the batch
workspace_size_list_.push_back(batch_size_ * n_ * sizeof(int));
// array of device pointers to each batch matrix, used by the batched cuBLAS path
workspace_size_list_.push_back(batch_size_ * sizeof(void *));
// per-batch factorization return codes (info) for the batched cuBLAS path
workspace_size_list_.push_back(batch_size_ * sizeof(int));
return KRET_OK;
}
void LuGpuKernelMod::ResetResource() noexcept {
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
template <typename T>
void LuGpuKernelMod::BufferSize(T *batch_output_addr, int *lwork) {
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, lwork),
"cusolver query lu work size fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, lwork),
"cusolver query lu work size fail");
} else if constexpr (std::is_same_v<T, utils::Complex<float>>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnCgetrf_bufferSize(handle_, m_, n_, reinterpret_cast<cuComplex *>(batch_output_addr), lda_, lwork),
"cusolver query lu work size fail");
} else {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnZgetrf_bufferSize(handle_, m_, n_, reinterpret_cast<cuDoubleComplex *>(batch_output_addr), lda_, lwork),
"cusolver query lu work size fail");
}
}
template <typename T, typename S>
void LuGpuKernelMod::LaunchKernel_CuSolve(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cusolverDnSetStream failed");
T *batch_input_addr = GetDeviceAddress<T>(inputs, kDim0);
T *batch_output_addr = GetDeviceAddress<T>(outputs, kDim0);
T *d_work_ = nullptr;
S *batch_piv_output_addr = nullptr;
if (pivot_on_) {
batch_piv_output_addr = GetDeviceAddress<S>(outputs, kDim1);
}
int *info_output_addr = GetDeviceAddress<int>(workspace, kDim0);
T *dev_work = GetDeviceAddress<T>(workspace, kDim1);
int *dev_batch_piv = GetDeviceAddress<int>(workspace, kDim2);
// query working space of getrf
BufferSize(batch_output_addr, &lwork_);
// Transpose input data from rowMajor to colMajor.
MatrixTranspose(batch_input_addr, SizeToInt(input_elements_), m_, m_, dev_work, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
// malloc device working space of getrf
d_work_ = reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_));
for (size_t batch = 0; batch < batch_size_; ++batch) {
S *piv_output_addr = batch_piv_output_addr + batch * k_;
int *dev_piv = dev_batch_piv + batch * k_;
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSgetrf(handle_, m_, n_, dev_work + batch * m_ * n_, lda_, d_work_, dev_piv, info_output_addr),
"cusolver lu fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDgetrf(handle_, m_, n_, dev_work + batch * m_ * n_, lda_, d_work_, dev_piv, info_output_addr),
"cusolver lu fail");
} else if constexpr (std::is_same_v<T, utils::Complex<float>>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnCgetrf(handle_, m_, n_, reinterpret_cast<cuComplex *>(dev_work + batch * m_ * n_), lda_,
reinterpret_cast<cuComplex *>(d_work_), dev_piv, info_output_addr),
"cusolver lu fail");
} else {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnZgetrf(handle_, m_, n_, reinterpret_cast<cuDoubleComplex *>(dev_work + batch * m_ * n_), lda_,
reinterpret_cast<cuDoubleComplex *>(d_work_), dev_piv, info_output_addr),
"cusolver lu fail");
}
std::vector<int> host_permuted(k_, 0);
std::vector<int> host_pivots(k_, 0);
std::vector<S> host_p(k_, 0);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(host_pivots.data(), dev_piv, sizeof(int) * k_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync failed in LuGpuKernelMod::Launch copy pivots to host.");
// Convert the 1-based LAPACK pivots into a 0-based row permutation.
for (size_t i = 0; i < k_; ++i) {
host_pivots[i] -= 1;
host_permuted[i] = i;
}
for (size_t i = 0; i < k_; ++i) {
int tmp_value = host_permuted[i];
host_permuted[i] = host_permuted[host_pivots[i]];
host_permuted[host_pivots[i]] = tmp_value;
}
for (size_t i = 0; i < k_; ++i) {
host_p[i] = host_permuted[i];
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(piv_output_addr, host_p.data(), sizeof(S) * k_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync failed in LuGpuKernelMod::Launch copy pivots array.");
}
MatrixTranspose(dev_work, SizeToInt(input_elements_), m_, m_, batch_output_addr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_);
}
template <typename T, typename S>
void LuGpuKernelMod::LaunchKernel_Cublas(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *batch_input_addr = GetDeviceAddress<T>(inputs, kDim0);
T *batch_output_addr = GetDeviceAddress<T>(outputs, kDim0);
S *batch_piv_output_addr = nullptr;
if (pivot_on_) {
batch_piv_output_addr = GetDeviceAddress<S>(outputs, kDim1);
}
T *dev_transpose_work = GetDeviceAddress<T>(workspace, kDim1);
auto dev_batch_piv = GetDeviceAddress<int>(workspace, kDim2);
auto batch_lu_device_address = GetDeviceAddress<T *>(workspace, kDim3);
auto info = GetDeviceAddress<int>(workspace, kDim4);
std::vector<T *> batch_lu_address_data;
for (size_t i = 0; i < batch_size_; i++) {
batch_lu_address_data.emplace_back(dev_transpose_work + i * m_ * m_);
}
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(batch_lu_device_address, batch_lu_address_data.data(), sizeof(T *) * batch_size_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"LuGpuKernelMod cudaMemcpyAsync Fail");
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(cublasSetStream(cublas_handle_, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"For LuGpuKernelMod cublasSetStream Fail");
// Transpose input data from rowMajor to colMajor.
MatrixTranspose(batch_input_addr, SizeToInt(input_elements_), m_, m_, dev_transpose_work, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
if constexpr (std::is_same_v<T, float>) {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasSgetrfBatched(cublas_handle_, m_, reinterpret_cast<float **>(batch_lu_device_address), m_, dev_batch_piv,
info, SizeToInt(batch_size_)),
"LuGpuKernelMod cublasSgetrfBatched Fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasDgetrfBatched(cublas_handle_, m_, reinterpret_cast<double **>(batch_lu_device_address), m_, dev_batch_piv,
info, SizeToInt(batch_size_)),
"LuGpuKernelMod cublasDgetrfBatched Fail");
} else if constexpr (std::is_same_v<T, utils::Complex<float>>) {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasCgetrfBatched(cublas_handle_, m_, reinterpret_cast<cuComplex **>(batch_lu_device_address), m_,
dev_batch_piv, info, SizeToInt(batch_size_)),
"LuGpuKernelMod cublasCgetrfBatched Fail");
} else if constexpr (std::is_same_v<T, utils::Complex<double>>) {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasZgetrfBatched(cublas_handle_, m_, reinterpret_cast<cuDoubleComplex **>(batch_lu_device_address), m_,
dev_batch_piv, info, SizeToInt(batch_size_)),
"LuGpuKernelMod cublasZgetrfBatched Fail");
} else {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', it's the input data type must be float32, float64, complex64 or complex128.";
}
MatrixTranspose(dev_transpose_work, SizeToInt(input_elements_), m_, m_, batch_output_addr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
std::vector<int> host_permuted(batch_size_ * k_, 0);
std::vector<int> host_pivots(batch_size_ * k_, 0);
std::vector<S> host_p(batch_size_ * k_, 0);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(host_pivots.data(), dev_batch_piv, sizeof(int) * batch_size_ * k_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync failed in LuGpuKernelMod::Launch copy pivots to host.");
for (size_t i = 0; i < batch_size_; ++i) {
for (size_t j = 0; j < k_; ++j) {
host_permuted[i * k_ + j] = j;
host_pivots[i * k_ + j] -= 1;
}
for (size_t j = 0; j < k_; ++j) {
int tmp_value = host_permuted[i * k_ + j];
host_permuted[i * k_ + j] = host_permuted[i * k_ + host_pivots[i * k_ + j]];
host_permuted[i * k_ + host_pivots[i * k_ + j]] = tmp_value;
}
}
for (size_t i = 0; i < batch_size_; ++i) {
for (size_t j = 0; j < k_; ++j) {
host_p[i * k_ + j] = host_permuted[i * k_ + j];
}
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(batch_piv_output_addr, host_p.data(), sizeof(S) * batch_size_ * k_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync failed in LuGpuKernelMod::Launch copy pivots array.");
}
template <typename T, typename S>
bool LuGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
// Empirically, the batched cuBLAS api is faster for small matrices or large batches (m_ / batch_size_ <= 128);
// otherwise, the non-batched cuSOLVER api is faster.
const size_t kNumber128 = 128;
if (m_ / batch_size_ <= kNumber128) {
LaunchKernel_Cublas<T, S>(inputs, workspace, outputs);
} else {
LaunchKernel_CuSolve<T, S>(inputs, workspace, outputs);
}
return true;
}
bool LuGpuKernelMod::CheckLuShape() {
constexpr size_t lu_min_dim = 1;
if (in_shape_.size() <= lu_min_dim) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the rank of the input is " << in_shape_.size()
<< ", which is invalid; it must be at least 2.";
return false;
}
constexpr size_t lu_reverse_row_dim = 2;
lu_row_ = in_shape_.at(in_shape_.size() - lu_reverse_row_dim);
lu_col_ = in_shape_.at(in_shape_.size() - 1);
input_elements_ = std::accumulate(in_shape_.begin(), in_shape_.end(), size_t(1), std::multiplies<size_t>());
batch_size_ = lu_min_dim;
for (int batch = 0; batch < static_cast<int>(in_shape_.size() - lu_reverse_row_dim); ++batch) {
batch_size_ *= in_shape_.at(batch);
}
// set matrix row or col to be lead dimension
m_ = SizeToInt(lu_row_);
n_ = SizeToInt(lu_col_);
k_ = std::min(lu_row_, lu_col_);
lda_ = m_;
ldb_ = n_;
return true;
}
const std::vector<std::pair<KernelAttr, LuGpuKernelMod::KernelRunFunc>> &LuGpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, LuGpuKernelMod::KernelRunFunc>> func_list = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
&LuGpuKernelMod::LaunchKernel<float, int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
&LuGpuKernelMod::LaunchKernel<double, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeInt32),
&LuGpuKernelMod::LaunchKernel<utils::Complex<float>, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeInt32),
&LuGpuKernelMod::LaunchKernel<utils::Complex<double>, int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
&LuGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
&LuGpuKernelMod::LaunchKernel<double, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeInt64),
&LuGpuKernelMod::LaunchKernel<utils::Complex<float>, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeInt64),
&LuGpuKernelMod::LaunchKernel<utils::Complex<double>, int64_t>},
};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Lu, LuGpuKernelMod);
} // namespace kernel
} // namespace mindspore
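
The least obvious step in both launch paths above is the host-side loop that turns the 1-based LAPACK swap indices (ipiv) returned by getrf into the 0-based row permutation the operator outputs. The same conversion in standalone Python, as an illustration only (the function name and the example ipiv are not from the kernel):

def pivots_to_permutation(pivots):
    # getrf records, for each row i, the (1-based) row it was swapped with.
    # Replaying those swaps on the identity permutation yields p such that
    # row i of the factorization came from row p[i] of the original matrix.
    k = len(pivots)
    piv = [x - 1 for x in pivots]      # 1-based -> 0-based, as in the kernel loop
    perm = list(range(k))              # start from the identity permutation
    for i in range(k):                 # replay the recorded swaps in order
        perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
    return perm

# For the 3x3 matrix in the Lu docstring example further down, getrf reports
# ipiv = [2, 3, 3], which this conversion maps to p = [1, 2, 0].
print(pivots_to_permutation([2, 3, 3]))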

View File

@@ -14,244 +14,68 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LU_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_GPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include <algorithm>
#include <type_traits>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "include/common/utils/convert_utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class LUGpuKernelMod : public NativeGpuKernelMod {
class LuGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<LuGpuKernelMod> {
public:
LUGpuKernelMod() : is_null_input_(false) {}
~LUGpuKernelMod() = default;
LuGpuKernelMod() { ResetResource(); }
~LuGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
T *batch_input_addr = GetDeviceAddress<T>(inputs, kDim0);
T *batch_output_addr = GetDeviceAddress<T>(outputs, kDim0);
int *batch_piv_output_addr = nullptr;
if (pivot_on_) {
batch_piv_output_addr = GetDeviceAddress<int>(outputs, kDim1);
}
int *batch_permutation_addr = GetDeviceAddress<int>(outputs, kDim2);
int *info_output_addr = GetDeviceAddress<int>(workspace, kDim0);
size_t *dev_transpose_shape = GetDeviceAddress<size_t>(workspace, kDim1);
size_t *dev_transpose_axis = GetDeviceAddress<size_t>(workspace, kDim2);
constexpr size_t shape_2d = 2;
size_t host_transpose_shape[shape_2d] = {m_, n_};
size_t host_transpose_axis[shape_2d] = {1, 0};
T *dev_transpose_work = GetDeviceAddress<T>(workspace, kDim3);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(dev_transpose_axis, host_transpose_axis, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch.");
// 4. query working space of getrf
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
// 5. malloc device working space of getrf
d_work_ = reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_));
for (size_t batch = 0; batch < batch_size_; ++batch) {
T *output_addr = batch_output_addr + batch * m_ * n_;
int *permutation_addr = batch_permutation_addr + batch * k_ * k_;
int *piv_output_addr = batch_piv_output_addr + batch * k_;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(dev_transpose_shape, host_transpose_shape, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CalTranspose(m_ * n_, output_addr, dev_transpose_shape, dev_transpose_axis, shape_2d, dev_transpose_work,
reinterpret_cast<cudaStream_t>(stream_ptr));
// 6.lu factorization according to cuSolver api, outputs have been written to input's matrix.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else if constexpr (std::is_same_v<T, double>) {
// 6.lu factorization according to cuSolver api, outputs have been written to input's matrix.
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
size_t host_wk_transpose_shape[shape_2d] = {n_, m_};
cudaMemcpyAsync(dev_transpose_shape, host_wk_transpose_shape, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(m_ * n_, dev_transpose_work, dev_transpose_shape, dev_transpose_axis, shape_2d, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<int> host_permuted(k_, 0);
std::vector<int> host_pivots(k_, 0);
std::vector<int> host_permutation(k_ * k_, 0);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots to host.");
// cal pivots && permutation major by row.
for (size_t i = 0; i < k_; ++i) {
host_pivots[i] -= 1;
host_permuted[i] = i;
}
for (size_t i = 0; i < k_; ++i) {
int tmp_value = host_permuted[i];
host_permuted[i] = host_permuted[host_pivots[i]];
host_permuted[host_pivots[i]] = tmp_value;
}
// gpu default is P.A = LU, so here is col swap.
for (size_t i = 0; i < k_; ++i) {
host_permutation[host_permuted[i] * k_ + i] = 1;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy permutation matrix.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots array.");
}
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_);
return true;
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
return true;
}
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
batch_size_ = 1;
auto shape_signed = inputs[kIndex0]->GetShapeVector();
auto in_shape = Convert2SizeT(shape_signed);
// 2. check input shape not null
is_null_input_ = CHECK_SHAPE_NULL(in_shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return KRET_OK;
}
// 3. calculate input size
if (!InitInputSize(in_shape)) {
MS_LOG(ERROR) << "For 'PureCholeskyGpuKernel', input shape init failed.";
return KRET_RESIZE_FAILED;
}
return KRET_OK;
}
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
};
return support_list;
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
bool InitInputSize(const std::vector<size_t> &in_shape) {
constexpr size_t lu_min_dim = 1;
if (in_shape.size() <= lu_min_dim) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape is " << in_shape.size() << " which is invalid.";
}
constexpr size_t lu_reverse_row_dim = 2;
lu_row_ = in_shape.at(in_shape.size() - lu_reverse_row_dim);
lu_col_ = in_shape.at(in_shape.size() - 1);
batch_size_ = lu_min_dim;
for (int batch = 0; batch < static_cast<int>(in_shape.size() - lu_reverse_row_dim); ++batch) {
batch_size_ *= in_shape.at(batch);
}
// set matrix row or col to be lead dimension
m_ = SizeToInt(lu_row_);
n_ = SizeToInt(lu_col_);
k_ = std::min(lu_row_, lu_col_);
lda_ = m_;
ldb_ = n_;
InitSizeLists();
return true;
}
void ResetResource() noexcept;
void InitSizeLists() {
size_t input_size = batch_size_ * lu_row_ * lu_col_ * unit_size_;
input_size_list_.push_back(input_size);
bool CheckLuShape();
size_t output_size = batch_size_ * lu_row_ * lu_col_ * unit_size_;
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
size_t output_piv_size = 0;
if (pivot_on_) {
output_piv_size = batch_size_ * k_ * sizeof(int);
}
size_t output_permutation_size = batch_size_ * k_ * k_ * sizeof(int);
output_size_list_.resize(kDim3);
output_size_list_[kDim0] = output_size;
output_size_list_[kDim1] = output_piv_size;
output_size_list_[kDim2] = output_permutation_size;
template <typename T, typename S>
void LaunchKernel_CuSolve(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
// a device addr to place lu factor return code
workspace_size_list_.push_back(sizeof(int));
template <typename T, typename S>
void LaunchKernel_Cublas(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
// transpose 2d matrix scalar args workspace
constexpr size_t shape_2d = 2;
workspace_size_list_.push_back(shape_2d * sizeof(size_t));
workspace_size_list_.push_back(shape_2d * sizeof(size_t));
template <typename T>
void BufferSize(T *batch_output_addr, int *lwork);
// transpose workspace
workspace_size_list_.push_back(m_ * n_ * unit_size_);
}
size_t unit_size_{sizeof(T)};
bool is_null_input_{false};
bool pivot_on_{true};
std::vector<size_t> in_shape_;
size_t unit_size_{1};
size_t batch_size_{1};
size_t input_elements_{};
size_t lu_row_{0};
size_t lu_col_{0};
size_t k_{0};
@@ -260,12 +84,11 @@ class LUGpuKernelMod : public NativeGpuKernelMod {
size_t lda_{0};
size_t ldb_{0};
int lwork_{0};
bool pivot_on_{true};
T *d_work_{nullptr};
void *cuda_stream_{nullptr};
cusolverDnHandle_t handle_{nullptr};
bool is_null_input_;
cublasHandle_t cublas_handle_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LU_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_GPU_KERNEL_H_

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/lu_scipy_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUGpuKernelMod, double)
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,269 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <map>
#include <string>
#include <algorithm>
#include <type_traits>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "include/common/utils/convert_utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class LUGpuKernelMod : public NativeGpuKernelMod {
public:
LUGpuKernelMod() : is_null_input_(false) {}
~LUGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
T *batch_input_addr = GetDeviceAddress<T>(inputs, kDim0);
T *batch_output_addr = GetDeviceAddress<T>(outputs, kDim0);
int *batch_piv_output_addr = nullptr;
if (pivot_on_) {
batch_piv_output_addr = GetDeviceAddress<int>(outputs, kDim1);
}
int *batch_permutation_addr = GetDeviceAddress<int>(outputs, kDim2);
int *info_output_addr = GetDeviceAddress<int>(workspace, kDim0);
size_t *dev_transpose_shape = GetDeviceAddress<size_t>(workspace, kDim1);
size_t *dev_transpose_axis = GetDeviceAddress<size_t>(workspace, kDim2);
constexpr size_t shape_2d = 2;
size_t host_transpose_shape[shape_2d] = {m_, n_};
size_t host_transpose_axis[shape_2d] = {1, 0};
T *dev_transpose_work = GetDeviceAddress<T>(workspace, kDim3);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(dev_transpose_axis, host_transpose_axis, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch.");
// 4. query working space of getrf
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
// 5. malloc device working space of getrf
d_work_ = reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_));
for (size_t batch = 0; batch < batch_size_; ++batch) {
T *output_addr = batch_output_addr + batch * m_ * n_;
int *permutation_addr = batch_permutation_addr + batch * k_ * k_;
int *piv_output_addr = batch_piv_output_addr + batch * k_;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(dev_transpose_shape, host_transpose_shape, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CalTranspose(m_ * n_, output_addr, dev_transpose_shape, dev_transpose_axis, shape_2d, dev_transpose_work,
reinterpret_cast<cudaStream_t>(stream_ptr));
// 6.lu factorization according to cuSolver api, outputs have been written to input's matrix.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else if constexpr (std::is_same_v<T, double>) {
// 6.lu factorization according to cuSolver api, outputs have been written to input's matrix.
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
size_t host_wk_transpose_shape[shape_2d] = {n_, m_};
cudaMemcpyAsync(dev_transpose_shape, host_wk_transpose_shape, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(m_ * n_, dev_transpose_work, dev_transpose_shape, dev_transpose_axis, shape_2d, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<int> host_permuted(k_, 0);
std::vector<int> host_pivots(k_, 0);
std::vector<int> host_permutation(k_ * k_, 0);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots to host.");
// cal pivots && permutation major by row.
for (size_t i = 0; i < k_; ++i) {
host_pivots[i] -= 1;
host_permuted[i] = i;
}
for (size_t i = 0; i < k_; ++i) {
int tmp_value = host_permuted[i];
host_permuted[i] = host_permuted[host_pivots[i]];
host_permuted[host_pivots[i]] = tmp_value;
}
// gpu default is P.A = LU, so here is col swap.
for (size_t i = 0; i < k_; ++i) {
host_permutation[host_permuted[i] * k_ + i] = 1;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy permutation matrix.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernelMod::Launch copy pivots array.");
}
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_);
return true;
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
return true;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
batch_size_ = 1;
auto shape_signed = inputs[kIndex0]->GetShapeVector();
auto in_shape = Convert2SizeT(shape_signed);
// 2. check input shape not null
is_null_input_ = CHECK_SHAPE_NULL(in_shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return KRET_OK;
}
// 3. calculate input size
if (!InitInputSize(in_shape)) {
MS_LOG(ERROR) << "For 'PureCholeskyGpuKernel', input shape init failed.";
return KRET_RESIZE_FAILED;
}
return KRET_OK;
}
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
};
return support_list;
}
private:
bool InitInputSize(const std::vector<size_t> &in_shape) {
constexpr size_t lu_min_dim = 1;
if (in_shape.size() <= lu_min_dim) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape is " << in_shape.size() << " which is invalid.";
}
constexpr size_t lu_reverse_row_dim = 2;
lu_row_ = in_shape.at(in_shape.size() - lu_reverse_row_dim);
lu_col_ = in_shape.at(in_shape.size() - 1);
batch_size_ = lu_min_dim;
for (int batch = 0; batch < static_cast<int>(in_shape.size() - lu_reverse_row_dim); ++batch) {
batch_size_ *= in_shape.at(batch);
}
// set matrix row or col to be lead dimension
m_ = SizeToInt(lu_row_);
n_ = SizeToInt(lu_col_);
k_ = std::min(lu_row_, lu_col_);
lda_ = m_;
ldb_ = n_;
InitSizeLists();
return true;
}
void InitSizeLists() {
size_t input_size = batch_size_ * lu_row_ * lu_col_ * unit_size_;
input_size_list_.push_back(input_size);
size_t output_size = batch_size_ * lu_row_ * lu_col_ * unit_size_;
size_t output_piv_size = 0;
if (pivot_on_) {
output_piv_size = batch_size_ * k_ * sizeof(int);
}
size_t output_permutation_size = batch_size_ * k_ * k_ * sizeof(int);
output_size_list_.resize(kDim3);
output_size_list_[kDim0] = output_size;
output_size_list_[kDim1] = output_piv_size;
output_size_list_[kDim2] = output_permutation_size;
// a device addr to place lu factor return code
workspace_size_list_.push_back(sizeof(int));
// transpose 2d matrix scalar args workspace
constexpr size_t shape_2d = 2;
workspace_size_list_.push_back(shape_2d * sizeof(size_t));
workspace_size_list_.push_back(shape_2d * sizeof(size_t));
// transpose workspace
workspace_size_list_.push_back(m_ * n_ * unit_size_);
}
size_t unit_size_{sizeof(T)};
size_t batch_size_{1};
size_t lu_row_{0};
size_t lu_col_{0};
size_t k_{0};
size_t m_{0};
size_t n_{0};
size_t lda_{0};
size_t ldb_{0};
int lwork_{0};
bool pivot_on_{true};
T *d_work_{nullptr};
cusolverDnHandle_t handle_{nullptr};
bool is_null_input_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LU_SCIPY_GPU_KERNEL_H_

View File

@@ -1245,6 +1245,7 @@ GVAR_DEF(PrimitivePtr, kPrimGer, std::make_shared<Primitive>("Ger"));
GVAR_DEF(PrimitivePtr, kPrimCeil, std::make_shared<Primitive>("Ceil"));
GVAR_DEF(PrimitivePtr, kPrimDiagonal, std::make_shared<Primitive>(kDiagonal));
GVAR_DEF(PrimitivePtr, kPrimTrunc, std::make_shared<Primitive>("Trunc"));
GVAR_DEF(PrimitivePtr, kPrimLu, std::make_shared<Primitive>("Lu"));
GVAR_DEF(PrimitivePtr, kPrimLuSolve, std::make_shared<Primitive>("LuSolve"));
GVAR_DEF(PrimitivePtr, kPrimMatrixSolve, std::make_shared<Primitive>("MatrixSolve"));
GVAR_DEF(PrimitivePtr, kPrimTridiagonalSolve, std::make_shared<Primitive>(kTridiagonalSolve));

View File

@@ -15,68 +15,67 @@
*/
#include "ops/lu.h"
#include <algorithm>
#include "ops/op_utils.h"
#include "mindapi/ir/type.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kLUInputsNum = 1;
constexpr size_t kXDim = 2;
constexpr size_t kLastDim = 1;
constexpr size_t kPenultimateDim = 2;
abstract::TupleShapePtr LUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
abstract::TupleShapePtr LuInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto x_shape = x_shape_map[kShape];
auto x_output = std::make_shared<abstract::Shape>(x_shape);
if (IsDynamicRank(x_shape)) {
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_output, x_output, x_output});
constexpr int64_t number1 = 1;
constexpr int64_t number2 = 2;
const int64_t input_num = 1;
const int64_t rank = 2;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num,
prim_name);
auto input_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex0);
auto input_shape = input_shape_ptr->shape();
if (IsDynamicRank(input_shape)) {
abstract::ShapePtr rank_shape = std::make_shared<abstract::Shape>(ShapeVector({-2}));
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{rank_shape, rank_shape});
}
size_t x_shape_size = x_shape.size();
if (x_shape_size < kXDim) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "',"
<< " the dimension of hashmap must be greater than or equal to 2, but got: "
<< x_shape_size << ".";
std::vector<int64_t> p_shape(input_shape.begin(), (input_shape.end() - number1));
abstract::ShapePtr p_shape_ptr = std::make_shared<abstract::Shape>(p_shape);
auto input_rank = SizeToLong(input_shape.size());
CheckAndConvertUtils::CheckInteger("input rank", input_rank, kGreaterEqual, rank, prim_name);
int64_t size1 = input_shape[input_shape.size() - number1];
int64_t size2 = input_shape[input_shape.size() - number2];
if (size1 != size2) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', input_shape[-1] and input_shape[-2] must be same, but got " << size1 << " vs "
<< size2;
}
auto k_shape = std::min(x_shape[x_shape_size - kLastDim], x_shape[x_shape_size - kPenultimateDim]);
ShapeVector top_k_shape(x_shape.begin(), x_shape.end() - kPenultimateDim);
ShapeVector pivots_shape = top_k_shape;
pivots_shape.push_back(k_shape);
ShapeVector permutation_shape = pivots_shape;
permutation_shape.push_back(k_shape);
auto pivots_output = std::make_shared<abstract::Shape>(pivots_shape);
auto permutation_output = std::make_shared<abstract::Shape>(permutation_shape);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{x_output, pivots_output, permutation_output});
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{input_shape_ptr, p_shape_ptr});
}
TuplePtr LUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto x_type = input_args[0]->BuildType();
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, kInt32, kInt32});
TypePtr LuInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> lu_types = {kFloat32, kFloat64, kComplex64, kComplex128};
auto input_type = input_args[kInputIndex0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("input type", input_type, lu_types, prim->name());
const std::set<TypePtr> out_valid_types = {kInt32, kInt64};
ValuePtr out_type_value = prim->GetAttr("output_idx_type");
TypePtr type = dyn_cast<Type>(out_type_value);
(void)CheckAndConvertUtils::CheckTypeValid("p type", type, out_valid_types, prim->name());
return std::make_shared<Tuple>(std::vector<TypePtr>{input_type, type});
}
} // namespace
AbstractBasePtr LUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
MIND_API_OPERATOR_IMPL(Lu, BaseOperator);
AbstractBasePtr LuInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kLUInputsNum, primitive->name());
auto infer_type = LUInferType(primitive, input_args);
auto infer_shape = LUInferShape(primitive, input_args);
auto infer_type = LuInferType(primitive, input_args);
auto infer_shape = LuInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(LU, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(LU, prim::kPrimLU, LUInfer, nullptr, true);
REGISTER_PRIMITIVE_EVAL_IMPL(Lu, prim::kPrimLu, LuInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
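
The shape rule encoded by LuInferShape above is compact enough to restate directly; a Python sketch of just that rule (not the C++ API), under the assumption that a dynamic rank is represented as the shape [-2]:

def lu_infer_shape(input_shape):
    # lu keeps the input shape [..., M, M]; p drops the last axis -> [..., M].
    if input_shape == [-2]:              # dynamic rank: both outputs stay dynamic-rank
        return [-2], [-2]
    if len(input_shape) < 2:
        raise ValueError("Lu expects an input of rank >= 2, got rank %d" % len(input_shape))
    if input_shape[-1] != input_shape[-2]:
        raise ValueError("Lu expects square matrices, got shape %s" % (input_shape,))
    return list(input_shape), list(input_shape[:-1])

print(lu_infer_shape([4, 3, 3]))  # ([4, 3, 3], [4, 3])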

View File

@@ -18,25 +18,24 @@
#define MINDSPORE_CORE_OPS_LU_H_
#include <map>
#include <vector>
#include <set>
#include <string>
#include <memory>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLU = "LU";
class MIND_API LU : public BaseOperator {
constexpr auto kNameLu = "Lu";
class MIND_API Lu : public BaseOperator {
public:
MIND_API_BASE_MEMBER(LU);
LU() : BaseOperator(kNameLU) { InitIOName({"x"}, {"lu", "pivots", "permutation"}); }
MIND_API_BASE_MEMBER(Lu);
Lu() : BaseOperator(kNameLu) { InitIOName({"input"}, {"lu", "p"}); }
};
abstract::AbstractBasePtr LUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
abstract::AbstractBasePtr LuInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimLUPtr = std::shared_ptr<LU>;
using PrimLuPtr = std::shared_ptr<Lu>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LU_H_

View File

@@ -0,0 +1,82 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/lu_scipy.h"
#include <algorithm>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kLUInputsNum = 1;
constexpr size_t kXDim = 2;
constexpr size_t kLastDim = 1;
constexpr size_t kPenultimateDim = 2;
abstract::TupleShapePtr LUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto x_shape = x_shape_map[kShape];
auto x_output = std::make_shared<abstract::Shape>(x_shape);
if (IsDynamicRank(x_shape)) {
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_output, x_output, x_output});
}
size_t x_shape_size = x_shape.size();
if (x_shape_size < kXDim) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "',"
<< " the dimension of hashmap must be greater than or equal to 2, but got: "
<< x_shape_size << ".";
}
auto k_shape = std::min(x_shape[x_shape_size - kLastDim], x_shape[x_shape_size - kPenultimateDim]);
ShapeVector top_k_shape(x_shape.begin(), x_shape.end() - kPenultimateDim);
ShapeVector pivots_shape = top_k_shape;
pivots_shape.push_back(k_shape);
ShapeVector permutation_shape = pivots_shape;
permutation_shape.push_back(k_shape);
auto pivots_output = std::make_shared<abstract::Shape>(pivots_shape);
auto permutation_output = std::make_shared<abstract::Shape>(permutation_shape);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{x_output, pivots_output, permutation_output});
}
TuplePtr LUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto x_type = input_args[0]->BuildType();
return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, kInt32, kInt32});
}
} // namespace
AbstractBasePtr LUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kLUInputsNum, primitive->name());
auto infer_type = LUInferType(primitive, input_args);
auto infer_shape = LUInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(LU, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(LU, prim::kPrimLU, LUInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LU_SCIPY_H_
#define MINDSPORE_CORE_OPS_LU_SCIPY_H_
#include <map>
#include <vector>
#include <set>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLU = "LU";
class MIND_API LU : public BaseOperator {
public:
MIND_API_BASE_MEMBER(LU);
LU() : BaseOperator(kNameLU) { InitIOName({"x"}, {"lu", "pivots", "permutation"}); }
};
abstract::AbstractBasePtr LUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimLUPtr = std::shared_ptr<LU>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LU_SCIPY_H_

View File

@@ -6286,6 +6286,52 @@ class MatrixSolveLs(Primitive):
validator.check_value_type('fast', fast, [bool], self.name)
class Lu(Primitive):
"""
Computes the LU decomposition of one or more square matrices.
Args:
output_idx_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32` or
`mindspore.dtype.int64`. Default: `mindspore.dtype.int32`.
Inputs:
- **input** (Tensor) - A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
matrices of size `[M, M]`, with data type float32, float64, complex64, complex128.
Outputs:
- **lu** (Tensor) - A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the lower
triangular factor `L` with unit diagonal. Upper triangular part denotes the upper triangular factor `U`.
- **p** (Tensor) - Permutation of the rows encoded as a list of indices in `0..M-1`, shape is `[..., M]`.
Raises:
TypeError: If the dtype of `input` is not one of the following dtype:
float32, float64, complex64, complex128.
TypeError: If `output_idx_type` is neither int32 nor int64.
ValueError: If `input` rank is less than 2.
ValueError: If `input.shape[-1]` is not equal to `input.shape[-2]`.
Supported Platforms:
``GPU``
Examples:
>>> input = Tensor(np.array([[2.5,3.1,3.5], [4.7,1.9,0.2], [1.1,3.6,2.0]]), mindspore.float32)
>>> lu, p = ops.Lu(output_idx_type=mindspore.int32)(input)
>>> print(lu)
[[4.7 1.9 0.2 ]
[0.23404257 3.155319 1.9531915 ]
[0.5319149 0.6621713 2.1002696 ]]
>>> print(p)
[1 2 0]
"""
@prim_attr_register
def __init__(self, output_idx_type):
super().__init__(name="Lu")
self.init_prim_io_names(inputs=['input'], outputs=['lu', 'p'])
validator.check_type_name("output_idx_type", output_idx_type, [mstype.int32, mstype.int64], self.name)
self.add_prim_attr('output_idx_type', output_idx_type)
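
Reading the two outputs together: row i of the packed factorization corresponds to row p[i] of the original matrix, so the result can be checked numerically as below. This is a sketch based on the docstring above (it assumes a GPU context is already set; unpacking into L and U uses plain NumPy, not a MindSpore helper):

import numpy as np
import mindspore
from mindspore import Tensor, ops

a_np = np.array([[2.5, 3.1, 3.5], [4.7, 1.9, 0.2], [1.1, 3.6, 2.0]], dtype=np.float32)
lu, p = ops.Lu(output_idx_type=mindspore.int32)(Tensor(a_np))
lu_np, p_np = lu.asnumpy(), p.asnumpy()
l = np.tril(lu_np, -1) + np.eye(3, dtype=np.float32)  # unit lower-triangular factor
u = np.triu(lu_np)                                     # upper-triangular factor
# Row i of L @ U equals row p[i] of the original matrix.
assert np.allclose(l @ u, a_np[p_np], atol=1e-4)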
class LuSolve(Primitive):
r"""
Return the solution of the linear equation :math:`Ax = b` .

View File

@@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,83 +12,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import Generic
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.numpy as mnp
import mindspore.common.dtype as mstype
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
import scipy as scp
import numpy as np
import pytest
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
class LU(PrimitiveWithInfer):
"""
LU decomposition with partial pivoting
P.A = L.U
"""
@prim_attr_register
def __init__(self):
super().__init__(name="LU")
self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation'])
def __infer__(self, x):
x_shape = list(x['shape'])
x_dtype = x['dtype']
pivots_shape = []
permutation_shape = []
ndim = len(x_shape)
if ndim == 0:
pivots_shape = x_shape
permutation_shape = x_shape
elif ndim == 1:
pivots_shape = x_shape[:-1]
# permutation_shape = x_shape[:-1]
else:
pivots_shape = x_shape[-2:-1]
# permutation_shape = x_shape[-2:-1]
output = {
'shape': (x_shape, pivots_shape, permutation_shape),
'dtype': (x_dtype, mstype.int32, mstype.int32),
'value': None
}
return output
import scipy as scp
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops.operations import math_ops as P
class LuNet(nn.Cell):
def __init__(self):
def __init__(self, output_idx_type=mstype.int32):
super(LuNet, self).__init__()
self.lu = LU()
self.lu = P.Lu(output_idx_type=output_idx_type)
def construct(self, a):
return self.lu(a)
@pytest.mark.platform_x86_gpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_lu_net(n: int, dtype: Generic):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lu_dtype_float32():
"""
Feature: ALL To ALL
Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
Expectation: the result match to scipy
Feature: Lu gpu TEST.
Description: float32 test case for Lu.
Expectation: the result matches scipy.
"""
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
expect, _ = scp.linalg.lu_factor(a)
mscp_lu_net = LuNet()
# mindspore tensor is row major but gpu cusolver is col major, so we should transpose it.
tensor_a = Tensor(a)
tensor_a = mnp.transpose(tensor_a)
output, _, _ = mscp_lu_net(tensor_a)
# mindspore tensor is row major but gpu cusolver is col major, so we should transpose it.
output = mnp.transpose(output)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x_np = np.array([[2.5, 3.1, 3.5], [4.7, 1.9, 0.2], [1.1, 3.6, 2.0]])
expect, _ = scp.linalg.lu_factor(x_np)
input_x = Tensor(x_np.astype(np.float32))
net = LuNet(mstype.int32)
lu, _ = net(input_x)
rtol = 1.e-4
atol = 1.e-4
assert np.allclose(expect, lu.asnumpy(), rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lu_dtype_float64():
"""
Feature: Lu gpu TEST.
Description: float64 test case for Lu.
Expectation: the result matches scipy.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x_np = np.array([[3.5, 6.5, 3.1], [4.7, 1.9, 6.2], [1.5, 4.8, 2.3]])
expect, _ = scp.linalg.lu_factor(x_np)
input_x = Tensor(x_np.astype(np.float64))
net = LuNet(mstype.int64)
lu, _ = net(input_x)
rtol = 1.e-5
atol = 1.e-5
assert np.allclose(expect, output.asnumpy(), rtol=rtol, atol=atol)
assert np.allclose(expect, lu.asnumpy(), rtol=rtol, atol=atol)
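
The kernel and the infer functions above also register complex64/complex128 variants that these tests do not exercise; a test along the same lines might look like the sketch below (not part of this PR, and the fixed complex matrix is only illustrative):

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lu_dtype_complex64():
    """
    Feature: Lu gpu TEST.
    Description: complex64 test case for Lu (sketch).
    Expectation: the result matches scipy.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x_np = np.array([[2.5 + 1.0j, 3.1, 3.5], [4.7, 1.9 - 0.5j, 0.2], [1.1, 3.6, 2.0 + 2.0j]]).astype(np.complex64)
    expect, _ = scp.linalg.lu_factor(x_np)
    net = LuNet(mstype.int32)
    lu, _ = net(Tensor(x_np))
    assert np.allclose(expect, lu.asnumpy(), rtol=1e-4, atol=1e-4)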