!28988 fix lu batched for gpu && cpu backend
Merge pull request !28988 from zhuzhongrui/r1.6
commit b73f77a4da
@@ -30,7 +30,6 @@ constexpr size_t kLUOutputsNum = 3;
 constexpr size_t kLuIndex = 0;
 constexpr size_t kPivotsIndex = 1;
 constexpr size_t kPermutationIndex = 2;
-constexpr size_t kLUDefaultShape = 1;
 constexpr size_t kRowIndex = 2;
 constexpr size_t kColIndex = 1;
 constexpr int kZeroThreshold = INT32_MIN;
@@ -38,16 +37,20 @@ constexpr int kZeroThreshold = INT32_MIN;
 
 template <typename T>
 void LUCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
-  if (shape.empty()) {
-    MS_LOG_EXCEPTION << kernel_name_ << "shape is invalid.";
+  constexpr size_t lu_min_dim = 1;
+  constexpr size_t lu_max_dim = 3;
+  if (shape.size() < lu_min_dim || shape.size() > lu_max_dim) {
+    MS_LOG_EXCEPTION << kernel_name_ << "shape is " << shape.size() << " which is invalid.";
   }
-  if (shape.size() == kLUDefaultShape) {
-    *row = shape.front();
-    *col = 1;
-  } else {
-    *row = shape.at(shape.size() - kRowIndex);
-    *col = shape.at(shape.size() - kColIndex);
+  if (shape.size() == lu_max_dim) {
+    batch_ = shape.front();
+    *row = shape.at(lu_min_dim);
+    *col = shape.at(lu_max_dim - 1);
+    return;
   }
+  batch_ = 1;
+  *row = shape.front();
+  *col = shape.at(lu_min_dim);
 }
 
 template <typename T>
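For readers following the CPU change: the rewritten InitMatrixInfo treats a 3-D input as a batch of matrices and a lower-rank input as a single matrix with batch 1. A minimal Python sketch of that shape rule (illustration only; parse_lu_shape is a hypothetical helper and only the 2-D/3-D cases are shown):

    def parse_lu_shape(shape):
        # (batch, rows, cols) for a 3-D input; batch 1 for a 2-D input
        if len(shape) == 3:
            return shape[0], shape[1], shape[2]
        if len(shape) == 2:
            return 1, shape[0], shape[1]
        raise ValueError("expected a 2-D or 3-D shape, got rank %d" % len(shape))

    assert parse_lu_shape((3, 4, 5)) == (3, 4, 5)
    assert parse_lu_shape((4, 5)) == (1, 4, 5)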
@@ -84,7 +87,8 @@ T LUCPUKernel<T>::GetPermutatedValue(const T *lu_value, const std::vector<int> &
 }
 
 template <typename T>
-bool LUCPUKernel<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, size_t k, size_t rows) {
+bool LUCPUKernel<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k,
+                                            size_t rows) {
   T max_major_value = static_cast<T>(kZeroThreshold);
   int max_major_index = 0;
   for (size_t i = k; i < rows; ++i) {
@@ -98,7 +102,7 @@ bool LUCPUKernel<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *const
   int per_k = per_value->at(k);
   (*per_value)[k] = per_value->at(max_major_index);
   (*per_value)[max_major_index] = per_k;
-  pivots_[k] = max_major_index;
+  pivots[k] = max_major_index;
   return max_major_value != static_cast<T>(kZeroThreshold);
 }
 
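The signature change threads a per-batch pivots pointer through instead of writing to the shared member pivots_; pivots[k] records which row was chosen as the pivot at elimination step k, the usual partial-pivoting bookkeeping. A hedged NumPy sketch of that record (illustration only, not the kernel code):

    import numpy as np

    def partial_pivot_record(a):
        # pivots[k] = index of the row swapped with row k at step k
        a = np.array(a, dtype=float)
        n = a.shape[0]
        pivots = np.zeros(n, dtype=int)
        perm = np.arange(n)
        for k in range(n):
            p = k + int(np.argmax(np.abs(a[perm[k:], k])))
            pivots[k] = p
            perm[[k, p]] = perm[[p, k]]
        return pivots, perm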
@@ -114,102 +118,108 @@ bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                             const std::vector<kernel::AddressPtr> &workspace,
                             const std::vector<kernel::AddressPtr> &outputs) {
   // input matrix of (m,n) PA = LU
-  T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
-  T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
-  // pivots permutation value
-  pivots_ = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
-  // permutation matrix value
-  int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
+  T *batch_a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
+  T *batch_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
+  batch_pivots_ = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
+  int *batch_permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
   T *lu_ori_wk = reinterpret_cast<T *>(workspace[kLuIndex]->addr);
   T *lu_trans_wk = reinterpret_cast<T *>(workspace[kPivotsIndex]->addr);
-  // init pivots
-  std::vector<int> per_value(pivots_row_, 0);
-  for (size_t i = 0; i < pivots_row_; ++i) {
-    per_value[i] = i;
-  }
-
-  // 1. memcpy input to output, do full lu inplace.
-  (void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T));
-
-  int s = std::min(a_row_, a_col_);
-  // 2. do lu decompose inplace
-  for (int k = 0; k < s; ++k) {
-    // 2.1 choose major element of current col if return false means current col elements are all zero, just continue.
-    if (!UpdateMajorPermutation(lu_value, &per_value, k, lu_row_)) {
-      continue;
-    }
-    // 2.2 major element x --> (1/x), get inplace origin lu matrix value.
-    T value = static_cast<T>(1.0 / GetPermutatedValue(lu_value, per_value, k, k));
-    // 2.3 change major col values
-    for (size_t i = k + 1; i < lu_row_; ++i) {
-      T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, k) * value);
-      // set inplace new lu matrix value.
-      SetPermutatedValue(lu_value, per_value, i, k, y);
-    }
-
-    // 2.4 Gauss elimination core
-    for (size_t i = k + 1; i < lu_row_; ++i) {
-      for (size_t j = k + 1; j < lu_col_; ++j) {
-        T y =
-          static_cast<T>(GetPermutatedValue(lu_value, per_value, i, j) -
-                         GetPermutatedValue(lu_value, per_value, i, k) * GetPermutatedValue(lu_value, per_value, k, j));
-        SetPermutatedValue(lu_value, per_value, i, j, y);
+  for (size_t batch = 0; batch < batch_; ++batch) {
+    T *a_value = batch_a_value + batch * a_row_ * a_col_;
+    T *lu_value = batch_lu_value + batch * lu_row_ * lu_col_;
+    // pivots permutation value
+    int *pivots = batch_pivots_ + batch * pivots_row_ * pivots_col_;
+    // permutation matrix value
+    int *permutation_value = batch_permutation_value + batch * permutation_row_ * permutation_col_;
+    // init pivots
+    std::vector<int> per_value(pivots_col_, 0);
+    for (size_t i = 0; i < pivots_col_; ++i) {
+      per_value[i] = i;
+    }
+    // 1. memcpy input to output, do full lu inplace.
+    (void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T));
+
+    int s = std::min(a_row_, a_col_);
+    // 2. do lu decompose inplace
+    for (int k = 0; k < s; ++k) {
+      // 2.1 choose major element of current col if return false means current col elements are all zero, just continue.
+      if (!UpdateMajorPermutation(lu_value, &per_value, pivots, k, lu_row_)) {
+        continue;
+      }
+      // 2.2 major element x --> (1/x), get inplace origin lu matrix value.
+      T value = static_cast<T>(1.0 / GetPermutatedValue(lu_value, per_value, k, k));
+      // 2.3 change major col values
+      for (size_t i = k + 1; i < lu_row_; ++i) {
+        T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, k) * value);
+        // set inplace new lu matrix value.
+        SetPermutatedValue(lu_value, per_value, i, k, y);
+      }
+
+      // 2.4 Gauss elimination core
+      for (size_t i = k + 1; i < lu_row_; ++i) {
+        for (size_t j = k + 1; j < lu_col_; ++j) {
+          T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, j) -
+                               GetPermutatedValue(lu_value, per_value, i, k) *
+                                 GetPermutatedValue(lu_value, per_value, k, j));
+          SetPermutatedValue(lu_value, per_value, i, j, y);
+        }
       }
     }
-  }
-
-  // 3. calculate final lu by permutation list
-  std::unordered_map<int, std::pair<int, bool>> pivots_map;
-  for (int i = 0; i < static_cast<int>(lu_row_); ++i) {
-    pivots_map[per_value[i]] = {i, false};
-  }
-  int pivots_count = 0;
-  for (const auto &pivot : pivots_map) {
-    pivots_count++;
-    int key = pivot.first;
-    int index = pivot.second.first;
-    bool is_visited = pivot.second.second;
-    if (is_visited || index == (pivots_count - 1)) {
-      continue;
-    }
-
-    T *lu_ori_row = lu_value + index * lu_col_;
-    T *lu_trans_row = lu_value + key * lu_col_;
-    // copy ori data to trans lu
-    (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
-    // copy new data to ori data ptr
-    (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_trans_row, lu_col_ * sizeof(T));
-    // update pivot map
-    pivots_map[key] = {index, true};
-    // put ori data which stored in workspace to mapped new place
-    is_visited = pivots_map[index].second;
-    while (!is_visited) {
-      key = index;
-      index = pivots_map[key].first;
-      is_visited = pivots_map[key].second;
-      lu_ori_row = lu_value + index * lu_col_;
-      T *tmp_wk = lu_trans_wk;
-      lu_trans_wk = lu_ori_wk;
-      lu_ori_wk = tmp_wk;
-      // copy new ori data to trans workspace
-      (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
-      // copy new data to ori data place
-      (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_ori_wk, lu_col_ * sizeof(T));
-      pivots_map[key] = {index, true};
-    }
-  }
-
-  // 4. calculate final permutation matrix
-  // for PA = LU get: base + row * permutation_row_ + col
-  // for A = PLU get: base + col * permutation_row_ + row
-  // here, we do A = PLU which is same as scipy.
-  size_t count = permutation_col_ * permutation_row_ * sizeof(int);
-  (void)memset_s(reinterpret_cast<void *>(permutation_value), count, 0, count);
-  for (size_t i = 0; i < pivots_row_; ++i) {
-    int position = per_value[i];
-    int *per_addr = permutation_value + position * permutation_row_ + i;
-    *per_addr = 1;
+
+    // 3. calculate final lu by permutation list
+    std::unordered_map<int, std::pair<int, bool>> pivots_map;
+    for (int i = 0; i < static_cast<int>(lu_row_); ++i) {
+      pivots_map[per_value[i]] = {i, false};
+    }
+    int pivots_count = 0;
+    for (const auto &pivot : pivots_map) {
+      pivots_count++;
+      int key = pivot.first;
+      int index = pivot.second.first;
+      bool is_visited = pivot.second.second;
+      if (is_visited || index == (pivots_count - 1)) {
+        continue;
+      }
+
+      T *lu_ori_row = lu_value + index * lu_col_;
+      T *lu_trans_row = lu_value + key * lu_col_;
+      // copy ori data to trans lu
+      (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
+      // copy new data to ori data ptr
+      (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_trans_row, lu_col_ * sizeof(T));
+      // update pivot map
+      pivots_map[key] = {index, true};
+      // put ori data which stored in workspace to mapped new place
+      is_visited = pivots_map[index].second;
+      while (!is_visited) {
+        key = index;
+        index = pivots_map[key].first;
+        is_visited = pivots_map[key].second;
+        lu_ori_row = lu_value + index * lu_col_;
+        T *tmp_wk = lu_trans_wk;
+        lu_trans_wk = lu_ori_wk;
+        lu_ori_wk = tmp_wk;
+        // copy new ori data to trans workspace
+        (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
+        // copy new data to ori data place
+        (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_ori_wk, lu_col_ * sizeof(T));
+        pivots_map[key] = {index, true};
+      }
+    }
+
+    // 4. calculate final permutation matrix
+    // for PA = LU get: base + row * permutation_row_ + col
+    // for A = PLU get: base + col * permutation_row_ + row
+    // here, we do A = PLU which is same as scipy.
+    size_t count = permutation_col_ * permutation_row_ * sizeof(int);
+    (void)memset_s(reinterpret_cast<void *>(permutation_value), count, 0, count);
+    for (size_t i = 0; i < pivots_col_; ++i) {
+      int position = per_value[i];
+      int *per_addr = permutation_value + position * permutation_row_ + i;
+      *per_addr = 1;
+    }
   }
   return true;
 }
 }  // namespace kernel
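As a cross-check on what the batched CPU Launch now computes: each batch slice is factored independently and, per the step-4 comment, the permutation output follows the scipy convention a = P @ L @ U. A NumPy/SciPy sketch of the expected per-batch result (reference only, not the kernel):

    import numpy as np
    from scipy.linalg import lu as scipy_lu

    def batched_lu_reference(batch_a):
        # factor each batch element independently; batch_a[i] = P[i] @ L[i] @ U[i]
        ps, ls, us = zip(*(scipy_lu(a) for a in batch_a))
        return np.stack(ps), np.stack(ls), np.stack(us)

    batch_a = np.random.rand(3, 4, 4)
    p, l, u = batched_lu_reference(batch_a)
    assert np.allclose(batch_a, p @ l @ u)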
@@ -36,8 +36,9 @@ class LUCPUKernel : public CPUKernel {
   void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
   void InitInputOutputSize(const CNodePtr &kernel_node) override;
   T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
-  bool UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, size_t k, size_t rows);
+  bool UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k, size_t rows);
   void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
+  size_t batch_{1};
   size_t a_row_{1};
   size_t a_col_{1};
   size_t lu_row_{1};
@@ -47,7 +48,7 @@ class LUCPUKernel : public CPUKernel {
   size_t permutation_row_{1};
   size_t permutation_col_{1};
   TypeId dtype_{kNumberTypeFloat32};
-  int *pivots_{nullptr};
+  int *batch_pivots_{nullptr};
 };
 
 MS_REG_CPU_KERNEL_T(LU,
@@ -26,16 +26,10 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
 #include "utils/convert_utils.h"
-#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
 
 namespace mindspore {
 namespace kernel {
-constexpr size_t kLuInputsNum = 1;
-constexpr size_t kInputIndex = 0;
-constexpr size_t kLuOutputsNum = 1;
-constexpr size_t kOutputIndex = 0;
-constexpr size_t kLuDefaultShape = 1;
-constexpr size_t kLuNormalShape = 2;
 
 template <typename T>
 class LUGpuKernel : public GpuKernel {
@@ -53,54 +47,111 @@ class LUGpuKernel : public GpuKernel {
     }
     CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                   "cusolverDnSetStream failed");
-    auto input_addr = GetDeviceAddress<T>(inputs, kDim0);
-    auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
-    int *piv_output_addr = nullptr;
+    T *batch_input_addr = GetDeviceAddress<T>(inputs, kDim0);
+    T *batch_output_addr = GetDeviceAddress<T>(outputs, kDim0);
+    int *batch_piv_output_addr = nullptr;
     if (pivot_on_) {
-      piv_output_addr = GetDeviceAddress<int>(outputs, kDim1);
+      batch_piv_output_addr = GetDeviceAddress<int>(outputs, kDim1);
    }
+    int *batch_permutation_addr = GetDeviceAddress<int>(outputs, kDim2);
+    int *info_output_addr = GetDeviceAddress<int>(workspace, kDim0);
 
-    auto info_output_addr = GetDeviceAddress<int>(outputs, kDim2);
+    size_t *dev_transpose_shape = GetDeviceAddress<size_t>(workspace, kDim1);
+    size_t *dev_transpose_axis = GetDeviceAddress<size_t>(workspace, kDim2);
+    constexpr size_t shape_2d = 2;
+    size_t host_transpose_shape[shape_2d] = {m_, n_};
+    size_t host_transpose_axis[shape_2d] = {1, 0};
+    T *dev_transpose_work = GetDeviceAddress<T>(workspace, kDim3);
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(dev_transpose_axis, host_transpose_axis, shape_2d * sizeof(size_t),
+                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "malloc input shape workspace failed");
+
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_ * m_ * n_ * unit_size_,
+                                               cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync failed in LUGpuKernel::Launch.");
 
     // 4. query working space of getrf
     if constexpr (std::is_same_v<T, float>) {
       CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
-                                     cusolverDnSgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_),
+                                     cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
                                      "cusolver query lu work size fail");
-
-      if (cudaMalloc(reinterpret_cast<void **>(&d_work_), unit_size_ * lwork_) != cudaSuccess) {
-        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cusolver malloc work size fail";
-      }
-
-      CHECK_CUSOLVER_RET_WITH_EXCEPT(
-        kernel_node_, cusolverDnSgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr),
-        "cusolver lu fail");
-
     } else if constexpr (std::is_same_v<T, double>) {
       CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
-                                     cusolverDnDgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_),
+                                     cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
                                      "cusolver query lu work size fail");
-      // 5. malloc device working space of getrf
-
-      if (cudaMalloc(reinterpret_cast<void **>(&d_work_), unit_size_ * lwork_) != cudaSuccess) {
-        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cusolver malloc work size fail";
-      }
-
-      // 6. solve to lu factorization according to cuSolver api, outputs have been written to input's matrix.
-      CHECK_CUSOLVER_RET_WITH_EXCEPT(
-        kernel_node_, cusolverDnDgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr),
-        "cusolver lu fail");
     } else {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
     }
-    // 7. copy results from written input's matrix to output's matrix.
-    // if (cudaMemcpy(output_addr, input_addr, lda_ * m_ * unit_size_, cudaMemcpyDeviceToDevice) != cudaSuccess) {
-    //   MS_LOG(EXCEPTION) << "memcpy lu output fail.";
-    // }
-    MatrixCopy(input_addr, output_addr, lda_ * m_, reinterpret_cast<cudaStream_t>(stream_ptr));
-    if (d_work_) {
-      cudaFree(d_work_);
+    // 5. malloc device working space of getrf
+    d_work_ = reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_));
+    for (size_t batch = 0; batch < batch_; ++batch) {
+      T *output_addr = batch_output_addr + batch * m_ * n_;
+      int *permutation_addr = batch_permutation_addr + batch * k_ * k_;
+      int *piv_output_addr = batch_piv_output_addr + batch * k_;
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                                 cudaMemcpyAsync(dev_transpose_shape, host_transpose_shape, shape_2d * sizeof(size_t),
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "malloc input shape workspace failed");
+
+      CalTranspose(m_ * n_, output_addr, dev_transpose_shape, dev_transpose_axis, shape_2d, dev_transpose_work,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+
+      // 6.lu factorization according to cuSolver api, outputs have been written to input's matrix.
+      if constexpr (std::is_same_v<T, float>) {
+        CHECK_CUSOLVER_RET_WITH_EXCEPT(
+          kernel_node_,
+          cusolverDnSgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
+          "cusolver lu fail");
+      } else if constexpr (std::is_same_v<T, double>) {
+        CHECK_CUSOLVER_RET_WITH_EXCEPT(
+          kernel_node_,
+          cusolverDnDgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
+          "cusolver lu fail");
+      } else {
+        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
+      }
+
+      size_t host_wk_transpose_shape[shape_2d] = {n_, m_};
+      cudaMemcpyAsync(dev_transpose_shape, host_wk_transpose_shape, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalTranspose(m_ * n_, dev_transpose_work, dev_transpose_shape, dev_transpose_axis, shape_2d, output_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+      std::vector<int> host_permuted(k_, 0);
+      std::vector<int> host_pivots(k_, 0);
+      std::vector<int> host_permutation(k_ * k_, 0);
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                                 cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_,
+                                                 cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed in LUGpuKernel::Launch copy pivots to host.");
+
+      // cal pivots && permutation major by row.
+      for (size_t i = 0; i < k_; ++i) {
+        host_pivots[i] -= 1;
+        host_permuted[i] = i;
+      }
+      for (size_t i = 0; i < k_; ++i) {
+        int tmp_value = host_permuted[i];
+        host_permuted[i] = host_permuted[host_pivots[i]];
+        host_permuted[host_pivots[i]] = tmp_value;
+      }
+      // gpu default is P.A = LU, so here is col swap.
+      for (size_t i = 0; i < k_; ++i) {
+        host_permutation[host_permuted[i] * k_ + i] = 1;
+      }
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                                 cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed in LUGpuKernel::Launch copy permutation matrix.");
+      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                                 cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed in LUGpuKernel::Launch copy pivots array.");
     }
+    device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_);
     return true;
   }
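The host-side loops above convert cuSolver getrf pivots (1-based; at step i, row i was exchanged with row ipiv[i]) into a 0-based pivot array and a dense permutation matrix; since cuSolver factors as P.A = LU, the matrix is filled column-wise. A NumPy sketch of the same conversion (illustration only):

    import numpy as np

    def cusolver_pivots_to_permutation(ipiv):
        piv = np.asarray(ipiv) - 1           # 1-based -> 0-based, as the kernel does
        k = piv.size
        permuted = np.arange(k)
        for i in range(k):                   # replay the sequential row exchanges
            permuted[i], permuted[piv[i]] = permuted[piv[i]], permuted[i]
        perm_matrix = np.zeros((k, k), dtype=int)
        for i in range(k):                   # column-wise fill: P.A = LU convention
            perm_matrix[permuted[i], i] = 1
        return piv, perm_matrix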
@@ -125,48 +176,65 @@ class LUGpuKernel : public GpuKernel {
 
  private:
   bool InitInputSize(const std::vector<size_t> &in_shape) {
-    if (in_shape.size() == kLuDefaultShape) {
-      lu_row_ = in_shape.at(kDim0);
-      lu_col_ = lu_row_;
-    } else if (in_shape.size() == kLuNormalShape) {
-      lu_row_ = in_shape.at(kDim0);
-      lu_col_ = in_shape.at(kDim1);
-    } else {
-      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input only should be 1 or 2";
-      return false;
+    constexpr size_t lu_min_dim = 1;
+    constexpr size_t lu_max_dim = 3;
+    if (in_shape.size() < lu_min_dim || in_shape.size() > lu_max_dim) {
+      MS_LOG_EXCEPTION << kernel_name_ << "shape is " << in_shape.size() << " which is invalid.";
     }
-    if (lu_row_ != lu_col_) {
-      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of input should be square matrix";
-      return false;
+    if (in_shape.size() == lu_max_dim) {
+      batch_ = in_shape.front();
+      lu_row_ = in_shape.at(lu_min_dim);
+      lu_col_ = in_shape.at(lu_max_dim - 1);
+    } else {
+      batch_ = 1;
+      lu_row_ = in_shape.front();
+      lu_col_ = in_shape.at(lu_min_dim);
     }
     // set matrix row or col to be lead dimension
     m_ = SizeToInt(lu_row_);
+    n_ = SizeToInt(lu_col_);
+    k_ = std::min(lu_row_, lu_col_);
     lda_ = m_;
-    ldb_ = m_;
+    ldb_ = n_;
     InitSizeLists();
     return true;
   }
 
   void InitSizeLists() override {
-    size_t input_size = lda_ * m_ * unit_size_;
+    size_t input_size = batch_ * lu_row_ * lu_col_ * unit_size_;
     input_size_list_.push_back(input_size);
 
-    size_t output_size = lda_ * m_ * unit_size_;
+    size_t output_size = batch_ * lu_row_ * lu_col_ * unit_size_;
 
     size_t output_piv_size = 0;
-    size_t output_info_size = sizeof(int);
     if (pivot_on_) {
-      output_piv_size = m_ * sizeof(int);
+      output_piv_size = batch_ * k_ * sizeof(int);
     }
+    size_t output_permutation_size = batch_ * k_ * k_ * sizeof(int);
     output_size_list_.resize(kDim3);
     output_size_list_[kDim0] = output_size;
     output_size_list_[kDim1] = output_piv_size;
-    output_size_list_[kDim2] = output_info_size;
+    output_size_list_[kDim2] = output_permutation_size;
 
+    // a device addr to place lu factor return code
+    workspace_size_list_.push_back(sizeof(int));
+
+    // transpose 2d matrix scalar args workspace
+    constexpr size_t shape_2d = 2;
+    workspace_size_list_.push_back(shape_2d * sizeof(size_t));
+    workspace_size_list_.push_back(shape_2d * sizeof(size_t));
+
+    // transpose workspace
+    workspace_size_list_.push_back(m_ * n_ * unit_size_);
   }
 
   size_t unit_size_{sizeof(T)};
+  size_t batch_{1};
   size_t lu_row_{0};
   size_t lu_col_{0};
+  size_t k_{0};
   size_t m_{0};
+  size_t n_{0};
   size_t lda_{0};
   size_t ldb_{0};
   int lwork_{0};
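For a concrete sense of the revised buffer layout, a small sketch computing the sizes InitSizeLists now requests for a (batch, m, n) input (illustration only; assumes 4-byte float/int and 8-byte size_t):

    def lu_gpu_buffer_sizes(batch, m, n, unit=4):
        k = min(m, n)
        outputs = [batch * m * n * unit,   # LU matrix
                   batch * k * 4,          # pivots
                   batch * k * k * 4]      # permutation matrix
        workspace = [4,                    # getrf info return code
                     2 * 8, 2 * 8,         # transpose shape and axis args
                     m * n * unit]         # transpose scratch matrix
        return outputs, workspace

    print(lu_gpu_buffer_sizes(3, 4, 5))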
@@ -514,7 +514,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
     [ 7.14285714e-01,  1.20000000e-01, -1.04000000e+00,  3.08000000e+00],
     [ 7.14285714e-01, -4.40000000e-01, -4.61538462e-01,  7.46153846e+00]])
     >>> piv
-    Tensor(shape=[4], dtype=Int32, value= [2, 0, 3, 1])
+    Tensor(shape=[4], dtype=Int32, value= [2, 2, 3, 3])
     """
     if F.dtype(a) not in float_types:
         a = F.cast(a, mstype.float32)
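The docstring change tracks the new pivot convention: lu_factor now returns LAPACK/scipy-style pivots, where row i was swapped with row piv[i], rather than the final permutation vector. The two are interconvertible; a small NumPy sketch (illustration only):

    import numpy as np

    def pivots_to_permutation(piv):
        # replay the sequential swaps (row i <-> row piv[i]) to get the permutation vector
        perm = np.arange(len(piv))
        for i, p in enumerate(piv):
            perm[[i, p]] = perm[[p, i]]
        return perm

    print(pivots_to_permutation([2, 2, 3, 3]))  # -> [2 0 3 1], the vector the old docstring showed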
@@ -592,9 +592,11 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
     m_lu, _, p = msp_lu(a)
     m = a.shape[-2]
     n = a.shape[-1]
+    if m > n:
+        _raise_value_error("last two dimensions of LU decomposition must be row less or equal to col.")
     k = min(m, n)
     a_dtype = a.dtype
-    l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype)
+    l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_dtype)
     u = mnp.triu(m_lu)[:k, :]
     if permute_l:
         return mnp.dot(p, l), u
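The slice change from [:, :k] to [..., :k] is what lets this path accept batched factors: on a 3-D m_lu, [:, :k] trims rows of every matrix in the batch, while [..., :k] trims the last axis (columns) as intended. A quick NumPy illustration:

    import numpy as np

    m_lu = np.zeros((3, 4, 5))                # batch of three 4 x 5 factors
    print(np.tril(m_lu, -1)[:, :2].shape)     # (3, 2, 5): slices rows, the wrong axis on 3-D input
    print(np.tril(m_lu, -1)[..., :2].shape)   # (3, 4, 2): slices columns, as intended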
@@ -642,7 +644,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
     check_lu_shape(m_lu, b)
     # here permutation array has been calculated, just use it.
     # 2. calculate permutation
-    permutation = lu_pivots_to_permutation(pivots, len(pivots))
+    permutation = lu_pivots_to_permutation(pivots, pivots.size)
     # 3. rhs_vector
     rhs_vector = m_lu.ndim == b.ndim + 1
     x = lu_solve_core(m_lu, permutation, b, trans)
@@ -301,13 +301,16 @@ class LU(PrimitiveWithInfer):
         x_shape = list(x['shape'])
         x_dtype = x['dtype']
         ndim = len(x_shape)
 
         if ndim in (1, 2):
-            permutation_shape = (x_shape[0], x_shape[0])
+            k_shape = min(x_shape[0], x_shape[1])
+            permutation_shape = (k_shape, k_shape)
+            pivots_shape = (1, k_shape)
         else:
-            permutation_shape = (x_shape[0], x_shape[1], x_shape[1])
+            k_shape = min(x_shape[1], x_shape[2])
+            permutation_shape = (x_shape[0], k_shape, k_shape)
+            pivots_shape = (x_shape[0], 1, k_shape)
         output = {
-            'shape': (x_shape, permutation_shape[:-1], permutation_shape),
+            'shape': (x_shape, pivots_shape, permutation_shape),
             'dtype': (x_dtype, mstype.int32, mstype.int32),
             'value': None
         }
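The revised infer rule gives the LU primitive three outputs: the LU-packed matrix with the input shape, pivots of length k = min(rows, cols), and a k x k permutation per matrix. A sketch of the shape computation (illustration only; mirrors the 2-D and 3-D branches above):

    def lu_infer_shapes(x_shape):
        if len(x_shape) == 2:
            k = min(x_shape[0], x_shape[1])
            return x_shape, (1, k), (k, k)
        k = min(x_shape[1], x_shape[2])
        return x_shape, (x_shape[0], 1, k), (x_shape[0], k, k)

    assert lu_infer_shapes((4, 5)) == ((4, 5), (1, 4), (4, 4))
    assert lu_infer_shapes((3, 4, 5)) == ((3, 4, 5), (3, 1, 4), (3, 4, 4))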
@@ -217,9 +217,10 @@ def test_eigh_complex(n: int, dtype):
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (10, 5), (20, 20)])
+@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (5, 10), (20, 20)])
 @pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
 def test_lu(shape: (int, int), dtype):
     """
@@ -238,6 +239,36 @@ def test_lu(shape: (int, int), dtype):
     assert onp.allclose(m_u.asnumpy(), s_u, rtol=rtol, atol=atol)
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5)])
+@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
+def test_batch_lu(shape: (int, int, int), dtype):
+    """
+    Feature: ALL To ALL
+    Description: test cases for batched lu decomposition
+    Expectation: the result match to scipy
+    """
+    b_a = create_random_rank_matrix(shape, dtype)
+    b_s_p = list()
+    b_s_l = list()
+    b_s_u = list()
+    for a in b_a:
+        s_p, s_l, s_u = osp.linalg.lu(a)
+        b_s_p.append(s_p)
+        b_s_l.append(s_l)
+        b_s_u.append(s_u)
+    tensor_b_a = Tensor(onp.array(b_a))
+    b_m_p, b_m_l, b_m_u = msp.linalg.lu(tensor_b_a)
+    rtol = 1.e-5
+    atol = 1.e-5
+    assert onp.allclose(b_m_p.asnumpy(), b_s_p, rtol=rtol, atol=atol)
+    assert onp.allclose(b_m_l.asnumpy(), b_s_l, rtol=rtol, atol=atol)
+    assert onp.allclose(b_m_u.asnumpy(), b_s_u, rtol=rtol, atol=atol)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard