!28988 fix lu batched for gpu && cpu backend

Merge pull request !28988 from zhuzhongrui/r1.6
i-robot 2022-01-13 08:03:46 +00:00 committed by Gitee
commit b73f77a4da
6 changed files with 275 additions and 160 deletions
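
For orientation, a minimal usage sketch of the batched LU path this change enables, written against the new test_batch_lu case further below (the msp.linalg.lu API and the A = P @ L @ U convention are taken from the diff; the random input is illustrative):

import numpy as onp
from mindspore import Tensor
import mindspore.scipy as msp

b_a = onp.random.random((3, 4, 4)).astype(onp.float32)  # a batch of three 4x4 matrices
b_p, b_l, b_u = msp.linalg.lu(Tensor(b_a))              # batched A = P @ L @ U, as in scipy
for a, p, l, u in zip(b_a, b_p.asnumpy(), b_l.asnumpy(), b_u.asnumpy()):
    assert onp.allclose(p @ l @ u, a, rtol=1e-5, atol=1e-5)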


@@ -30,7 +30,6 @@ constexpr size_t kLUOutputsNum = 3;
constexpr size_t kLuIndex = 0;
constexpr size_t kPivotsIndex = 1;
constexpr size_t kPermutationIndex = 2;
constexpr size_t kLUDefaultShape = 1;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
constexpr int kZeroThreshold = INT32_MIN;
@@ -38,16 +37,20 @@ constexpr int kZeroThreshold = INT32_MIN;
template <typename T>
void LUCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
if (shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << "shape is invalid.";
constexpr size_t lu_min_dim = 1;
constexpr size_t lu_max_dim = 3;
if (shape.size() < lu_min_dim || shape.size() > lu_max_dim) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape rank is " << shape.size() << ", which is invalid.";
}
if (shape.size() == kLUDefaultShape) {
*row = shape.front();
*col = 1;
} else {
*row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex);
if (shape.size() == lu_max_dim) {
batch_ = shape.front();
*row = shape.at(lu_min_dim);
*col = shape.at(lu_max_dim - 1);
return;
}
batch_ = 1;
*row = shape.front();
*col = shape.at(lu_min_dim);
}
template <typename T>
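
In Python terms, the new InitMatrixInfo logic reduces to the following sketch (parse_lu_shape is a hypothetical helper, not part of the diff):

def parse_lu_shape(shape):
    # Mirror of InitMatrixInfo: 1-D input is a column vector, 2-D a single
    # matrix, 3-D a batch of matrices; returns (batch, row, col).
    if not 1 <= len(shape) <= 3:
        raise ValueError(f"LU input rank {len(shape)} is invalid")
    if len(shape) == 3:
        return shape[0], shape[1], shape[2]
    if len(shape) == 2:
        return 1, shape[0], shape[1]
    return 1, shape[0], 1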
@@ -84,7 +87,8 @@ T LUCPUKernel<T>::GetPermutatedValue(const T *lu_value, const std::vector<int> &
}
template <typename T>
bool LUCPUKernel<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, size_t k, size_t rows) {
bool LUCPUKernel<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k,
size_t rows) {
T max_major_value = static_cast<T>(kZeroThreshold);
int max_major_index = 0;
for (size_t i = k; i < rows; ++i) {
@@ -98,7 +102,7 @@ bool LUCPUKernel<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *const
int per_k = per_value->at(k);
(*per_value)[k] = per_value->at(max_major_index);
(*per_value)[max_major_index] = per_k;
pivots_[k] = max_major_index;
pivots[k] = max_major_index;
return max_major_value != static_cast<T>(kZeroThreshold);
}
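
A Python sketch of the pivot selection above; the comparison loop elided by the hunk is assumed to pick the largest-magnitude entry of column k:

def update_major_permutation(lu, per, pivots, k, rows):
    max_val, max_idx = float("-inf"), k
    for i in range(k, rows):                     # scan column k through the permutation list
        v = abs(lu[per[i]][k])                   # magnitude comparison is an assumption
        if v > max_val:
            max_val, max_idx = v, i
    per[k], per[max_idx] = per[max_idx], per[k]  # swap the major row into position k
    pivots[k] = max_idx                          # record the pivot index for this column
    return max_val != 0.0                        # False -> the whole column is zero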
@@ -114,102 +118,108 @@ bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
// input matrix of (m,n) PA = LU
T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
// pivots permutation value
pivots_ = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
// permutation matrix value
int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
T *batch_a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
T *batch_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
batch_pivots_ = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
int *batch_permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
T *lu_ori_wk = reinterpret_cast<T *>(workspace[kLuIndex]->addr);
T *lu_trans_wk = reinterpret_cast<T *>(workspace[kPivotsIndex]->addr);
// init pivots
std::vector<int> per_value(pivots_row_, 0);
for (size_t i = 0; i < pivots_row_; ++i) {
per_value[i] = i;
}
// 1. memcpy input to output, do full lu inplace.
(void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T));
int s = std::min(a_row_, a_col_);
// 2. do lu decompose inplace
for (int k = 0; k < s; ++k) {
// 2.1 choose the major element of the current col; a false return means the col is all zero, so just continue.
if (!UpdateMajorPermutation(lu_value, &per_value, k, lu_row_)) {
continue;
}
// 2.2 major element x --> (1/x), get inplace origin lu matrix value.
T value = static_cast<T>(1.0 / GetPermutatedValue(lu_value, per_value, k, k));
// 2.3 change major col values
for (size_t i = k + 1; i < lu_row_; ++i) {
T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, k) * value);
// set inplace new lu matrix value.
SetPermutatedValue(lu_value, per_value, i, k, y);
for (size_t batch = 0; batch < batch_; ++batch) {
T *a_value = batch_a_value + batch * a_row_ * a_col_;
T *lu_value = batch_lu_value + batch * lu_row_ * lu_col_;
// pivots permutation value
int *pivots = batch_pivots_ + batch * pivots_row_ * pivots_col_;
// permutation matrix value
int *permutation_value = batch_permutation_value + batch * permutation_row_ * permutation_col_;
// init pivots
std::vector<int> per_value(pivots_col_, 0);
for (size_t i = 0; i < pivots_col_; ++i) {
per_value[i] = i;
}
// 1. memcpy input to output, do full lu inplace.
(void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T));
// 2.4 Gauss elimination core
for (size_t i = k + 1; i < lu_row_; ++i) {
for (size_t j = k + 1; j < lu_col_; ++j) {
T y =
static_cast<T>(GetPermutatedValue(lu_value, per_value, i, j) -
GetPermutatedValue(lu_value, per_value, i, k) * GetPermutatedValue(lu_value, per_value, k, j));
SetPermutatedValue(lu_value, per_value, i, j, y);
int s = std::min(a_row_, a_col_);
// 2. do lu decompose inplace
for (int k = 0; k < s; ++k) {
// 2.1 choose the major element of the current col; a false return means the col is all zero, so just continue.
if (!UpdateMajorPermutation(lu_value, &per_value, pivots, k, lu_row_)) {
continue;
}
// 2.2 major element x --> (1/x), get inplace origin lu matrix value.
T value = static_cast<T>(1.0 / GetPermutatedValue(lu_value, per_value, k, k));
// 2.3 change major col values
for (size_t i = k + 1; i < lu_row_; ++i) {
T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, k) * value);
// set inplace new lu matrix value.
SetPermutatedValue(lu_value, per_value, i, k, y);
}
// 2.4 Gauss elimination core
for (size_t i = k + 1; i < lu_row_; ++i) {
for (size_t j = k + 1; j < lu_col_; ++j) {
T y = static_cast<T>(GetPermutatedValue(lu_value, per_value, i, j) -
GetPermutatedValue(lu_value, per_value, i, k) *
GetPermutatedValue(lu_value, per_value, k, j));
SetPermutatedValue(lu_value, per_value, i, j, y);
}
}
}
}
// 3. calculate final lu by permutation list
std::unordered_map<int, std::pair<int, bool>> pivots_map;
for (int i = 0; i < static_cast<int>(lu_row_); ++i) {
pivots_map[per_value[i]] = {i, false};
}
int pivots_count = 0;
for (const auto &pivot : pivots_map) {
pivots_count++;
int key = pivot.first;
int index = pivot.second.first;
bool is_visited = pivot.second.second;
if (is_visited || index == (pivots_count - 1)) {
continue;
// 3. calculate final lu by permutation list
std::unordered_map<int, std::pair<int, bool>> pivots_map;
for (int i = 0; i < static_cast<int>(lu_row_); ++i) {
pivots_map[per_value[i]] = {i, false};
}
int pivots_count = 0;
for (const auto &pivot : pivots_map) {
pivots_count++;
int key = pivot.first;
int index = pivot.second.first;
bool is_visited = pivot.second.second;
if (is_visited || index == (pivots_count - 1)) {
continue;
}
T *lu_ori_row = lu_value + index * lu_col_;
T *lu_trans_row = lu_value + key * lu_col_;
// copy ori data to trans lu
(void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
// copy new data to ori data ptr
(void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_trans_row, lu_col_ * sizeof(T));
// update pivot map
pivots_map[key] = {index, true};
// put the ori data stored in the workspace into its mapped new place
is_visited = pivots_map[index].second;
while (!is_visited) {
key = index;
index = pivots_map[key].first;
is_visited = pivots_map[key].second;
lu_ori_row = lu_value + index * lu_col_;
T *tmp_wk = lu_trans_wk;
lu_trans_wk = lu_ori_wk;
lu_ori_wk = tmp_wk;
// copy new ori data to trans workspace
T *lu_ori_row = lu_value + index * lu_col_;
T *lu_trans_row = lu_value + key * lu_col_;
// copy ori data to trans lu
(void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
// copy new data to ori data place
(void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_ori_wk, lu_col_ * sizeof(T));
// copy new data to ori data ptr
(void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_trans_row, lu_col_ * sizeof(T));
// update pivot map
pivots_map[key] = {index, true};
// put the ori data stored in the workspace into its mapped new place
is_visited = pivots_map[index].second;
while (!is_visited) {
key = index;
index = pivots_map[key].first;
is_visited = pivots_map[key].second;
lu_ori_row = lu_value + index * lu_col_;
T *tmp_wk = lu_trans_wk;
lu_trans_wk = lu_ori_wk;
lu_ori_wk = tmp_wk;
// copy new ori data to trans workspace
(void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T));
// copy new data to ori data place
(void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_ori_wk, lu_col_ * sizeof(T));
pivots_map[key] = {index, true};
}
}
// 4. calculate final permutation matrix
// for PA = LU get: base + row * permutation_row_ + col
// for A = PLU get: base + col * permutation_row_ + row
// here, we do A = PLU, the same as scipy.
size_t count = permutation_col_ * permutation_row_ * sizeof(int);
(void)memset_s(reinterpret_cast<void *>(permutation_value), count, 0, count);
for (size_t i = 0; i < pivots_col_; ++i) {
int position = per_value[i];
int *per_addr = permutation_value + position * permutation_row_ + i;
*per_addr = 1;
}
}
// 4. calculate final permutation matrix
// for PA = LU get: base + row * permutation_row_ + col
// for A = PLU get: base + col * permutation_row_ + row
// here, we do A = PLU which is same as scipy.
size_t count = permutation_col_ * permutation_row_ * sizeof(int);
(void)memset_s(reinterpret_cast<void *>(permutation_value), count, 0, count);
for (size_t i = 0; i < pivots_row_; ++i) {
int position = per_value[i];
int *per_addr = permutation_value + position * permutation_row_ + i;
*per_addr = 1;
}
return true;
}
} // namespace kernel
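
Putting the kernel's steps 1-4 together, a NumPy reference for the per-batch factorization above (a sketch under the A = P @ L @ U convention stated in the comments, not the kernel itself):

import numpy as np

def lu_reference(a):
    # Partial-pivoting LU of one (m, n) matrix with m <= n, A = P @ L @ U.
    a = np.asarray(a, dtype=float)
    m, n = a.shape
    k = min(m, n)
    lu, per = a.copy(), np.arange(m)
    for j in range(k):
        p = j + int(np.argmax(np.abs(lu[per[j:], j])))  # step 2.1: choose the major element
        per[[j, p]] = per[[p, j]]
        if lu[per[j], j] == 0:
            continue                                    # all-zero column: skip, as the kernel does
        lu[per[j + 1:], j] /= lu[per[j], j]             # steps 2.2/2.3: scale the L column
        lu[per[j + 1:], j + 1:] -= np.outer(lu[per[j + 1:], j], lu[per[j], j + 1:])  # step 2.4
    lu = lu[per]                                        # step 3: materialize the row permutation
    p_mat = np.zeros((m, m), dtype=np.int32)
    p_mat[per, np.arange(m)] = 1                        # step 4: P[per[i], i] = 1, i.e. A = PLU
    return p_mat, np.tril(lu, -1)[:, :k] + np.eye(m, k), np.triu(lu)[:k, :]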


@@ -36,8 +36,9 @@ class LUCPUKernel : public CPUKernel {
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitInputOutputSize(const CNodePtr &kernel_node) override;
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, size_t k, size_t rows);
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k, size_t rows);
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
size_t batch_{1};
size_t a_row_{1};
size_t a_col_{1};
size_t lu_row_{1};
@@ -47,7 +48,7 @@ class LUCPUKernel : public CPUKernel {
size_t permutation_row_{1};
size_t permutation_col_{1};
TypeId dtype_{kNumberTypeFloat32};
int *pivots_{nullptr};
int *batch_pivots_{nullptr};
};
MS_REG_CPU_KERNEL_T(LU,


@@ -26,16 +26,10 @@
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/convert_utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t kLuInputsNum = 1;
constexpr size_t kInputIndex = 0;
constexpr size_t kLuOutputsNum = 1;
constexpr size_t kOutputIndex = 0;
constexpr size_t kLuDefaultShape = 1;
constexpr size_t kLuNormalShape = 2;
template <typename T>
class LUGpuKernel : public GpuKernel {
@@ -53,54 +47,111 @@ class LUGpuKernel : public GpuKernel {
}
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cusolverDnSetStream failed");
auto input_addr = GetDeviceAddress<T>(inputs, kDim0);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
int *piv_output_addr = nullptr;
T *batch_input_addr = GetDeviceAddress<T>(inputs, kDim0);
T *batch_output_addr = GetDeviceAddress<T>(outputs, kDim0);
int *batch_piv_output_addr = nullptr;
if (pivot_on_) {
piv_output_addr = GetDeviceAddress<int>(outputs, kDim1);
batch_piv_output_addr = GetDeviceAddress<int>(outputs, kDim1);
}
int *batch_permutation_addr = GetDeviceAddress<int>(outputs, kDim2);
int *info_output_addr = GetDeviceAddress<int>(workspace, kDim0);
auto info_output_addr = GetDeviceAddress<int>(outputs, kDim2);
size_t *dev_transpose_shape = GetDeviceAddress<size_t>(workspace, kDim1);
size_t *dev_transpose_axis = GetDeviceAddress<size_t>(workspace, kDim2);
constexpr size_t shape_2d = 2;
size_t host_transpose_shape[shape_2d] = {m_, n_};
size_t host_transpose_axis[shape_2d] = {1, 0};
T *dev_transpose_work = GetDeviceAddress<T>(workspace, kDim3);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(dev_transpose_axis, host_transpose_axis, shape_2d * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_ * m_ * n_ * unit_size_,
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernel::Launch.");
// 4. query working space of getrf
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
cusolverDnSgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_),
cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
if (cudaMalloc(reinterpret_cast<void **>(&d_work_), unit_size_ * lwork_) != cudaSuccess) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cusolver malloc work size fail";
}
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnSgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_,
cusolverDnDgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_),
cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_),
"cusolver query lu work size fail");
// 5. malloc device working space of getrf
if (cudaMalloc(reinterpret_cast<void **>(&d_work_), unit_size_ * lwork_) != cudaSuccess) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cusolver malloc work size fail";
}
// 6. solve to lu factorization according to cuSolver api, outputs have been written to input's matrix.
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_, cusolverDnDgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
// 7. copy results from written input's matrix to output's matrix.
// if (cudaMemcpy(output_addr, input_addr, lda_ * m_ * unit_size_, cudaMemcpyDeviceToDevice) != cudaSuccess) {
// MS_LOG(EXCEPTION) << "memcpy lu output fail.";
// }
MatrixCopy(input_addr, output_addr, lda_ * m_, reinterpret_cast<cudaStream_t>(stream_ptr));
if (d_work_) {
cudaFree(d_work_);
// 5. malloc device working space of getrf
d_work_ = reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_));
for (size_t batch = 0; batch < batch_; ++batch) {
T *output_addr = batch_output_addr + batch * m_ * n_;
int *permutation_addr = batch_permutation_addr + batch * k_ * k_;
int *piv_output_addr = batch_piv_output_addr + batch * k_;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(dev_transpose_shape, host_transpose_shape, shape_2d * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"malloc input shape workspace failed");
CalTranspose(m_ * n_, output_addr, dev_transpose_shape, dev_transpose_axis, shape_2d, dev_transpose_work,
reinterpret_cast<cudaStream_t>(stream_ptr));
// 6. lu factorization according to the cuSolver api; outputs are written to the input matrix in place.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_,
cusolverDnSgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else if constexpr (std::is_same_v<T, double>) {
// 6. lu factorization according to the cuSolver api; outputs are written to the input matrix in place.
CHECK_CUSOLVER_RET_WITH_EXCEPT(
kernel_node_,
cusolverDnDgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr),
"cusolver lu fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
size_t host_wk_transpose_shape[shape_2d] = {n_, m_};
cudaMemcpyAsync(dev_transpose_shape, host_wk_transpose_shape, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(m_ * n_, dev_transpose_work, dev_transpose_shape, dev_transpose_axis, shape_2d, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<int> host_permuted(k_, 0);
std::vector<int> host_pivots(k_, 0);
std::vector<int> host_permutation(k_ * k_, 0);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_,
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernel::Launch copy pivots to host.");
// calculate pivots and the row-major permutation.
for (size_t i = 0; i < k_; ++i) {
host_pivots[i] -= 1;
host_permuted[i] = i;
}
for (size_t i = 0; i < k_; ++i) {
int tmp_value = host_permuted[i];
host_permuted[i] = host_permuted[host_pivots[i]];
host_permuted[host_pivots[i]] = tmp_value;
}
// gpu default is P.A = LU, so we do a col swap here.
for (size_t i = 0; i < k_; ++i) {
host_permutation[host_permuted[i] * k_ + i] = 1;
}
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernel::Launch copy permutation matrix.");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in LUGpuKernel::Launch copy pivots array.");
}
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_);
return true;
}
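
The host-side pivot handling above, restated in NumPy (getrf returns 1-based LAPACK-style swap indices; replaying the swaps yields the column-swapped permutation matrix the kernel builds):

import numpy as np

def ipiv_to_permutation(ipiv):
    k = len(ipiv)
    pivots = np.asarray(ipiv) - 1            # 1-based ipiv -> 0-based pivots
    permuted = np.arange(k)
    for i in range(k):                       # replay the recorded row swaps
        permuted[i], permuted[pivots[i]] = permuted[pivots[i]], permuted[i]
    p = np.zeros((k, k), dtype=np.int32)
    p[permuted, np.arange(k)] = 1            # host_permutation[host_permuted[i] * k_ + i] = 1
    return pivots, p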
@@ -125,48 +176,65 @@ class LUGpuKernel : public GpuKernel {
private:
bool InitInputSize(const std::vector<size_t> &in_shape) {
if (in_shape.size() == kLuDefaultShape) {
lu_row_ = in_shape.at(kDim0);
lu_col_ = lu_row_;
} else if (in_shape.size() == kLuNormalShape) {
lu_row_ = in_shape.at(kDim0);
lu_col_ = in_shape.at(kDim1);
} else {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input only should be 1 or 2";
return false;
constexpr size_t lu_min_dim = 1;
constexpr size_t lu_max_dim = 3;
if (in_shape.size() < lu_min_dim || in_shape.size() > lu_max_dim) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape rank is " << in_shape.size() << ", which is invalid.";
}
if (lu_row_ != lu_col_) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of input should be square matrix";
return false;
if (in_shape.size() == lu_max_dim) {
batch_ = in_shape.front();
lu_row_ = in_shape.at(lu_min_dim);
lu_col_ = in_shape.at(lu_max_dim - 1);
} else {
batch_ = 1;
lu_row_ = in_shape.front();
lu_col_ = in_shape.at(lu_min_dim);
}
// set matrix row or col to be lead dimension
m_ = SizeToInt(lu_row_);
n_ = SizeToInt(lu_col_);
k_ = std::min(lu_row_, lu_col_);
lda_ = m_;
ldb_ = m_;
ldb_ = n_;
InitSizeLists();
return true;
}
void InitSizeLists() override {
size_t input_size = lda_ * m_ * unit_size_;
size_t input_size = batch_ * lu_row_ * lu_col_ * unit_size_;
input_size_list_.push_back(input_size);
size_t output_size = lda_ * m_ * unit_size_;
size_t output_size = batch_ * lu_row_ * lu_col_ * unit_size_;
size_t output_piv_size = 0;
size_t output_info_size = sizeof(int);
if (pivot_on_) {
output_piv_size = m_ * sizeof(int);
output_piv_size = batch_ * k_ * sizeof(int);
}
size_t output_permutation_size = batch_ * k_ * k_ * sizeof(int);
output_size_list_.resize(kDim3);
output_size_list_[kDim0] = output_size;
output_size_list_[kDim1] = output_piv_size;
output_size_list_[kDim2] = output_info_size;
output_size_list_[kDim2] = output_permutation_size;
// a device addr to place lu factor return code
workspace_size_list_.push_back(sizeof(int));
// transpose 2d matrix scalar args workspace
constexpr size_t shape_2d = 2;
workspace_size_list_.push_back(shape_2d * sizeof(size_t));
workspace_size_list_.push_back(shape_2d * sizeof(size_t));
// transpose workspace
workspace_size_list_.push_back(m_ * n_ * unit_size_);
}
size_t unit_size_{sizeof(T)};
size_t batch_{1};
size_t lu_row_{0};
size_t lu_col_{0};
size_t k_{0};
size_t m_{0};
size_t n_{0};
size_t lda_{0};
size_t ldb_{0};
int lwork_{0};
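
Spelled out for one concrete case, the InitSizeLists buffer accounting above (a sketch assuming sizeof(int) == 4 and sizeof(size_t) == 8; the values are illustrative):

unit_size, batch, m, n = 4, 3, 4, 5                  # float32, batch of 3, 4x5 matrices
k = min(m, n)
input_size = batch * m * n * unit_size               # input matrix
output_sizes = [batch * m * n * unit_size,           # LU matrix
                batch * k * 4,                       # pivots (int)
                batch * k * k * 4]                   # permutation matrix (int)
workspace_sizes = [4,                                # getrf info return code (int)
                   2 * 8, 2 * 8,                     # transpose shape / axis args (size_t)
                   m * n * unit_size]                # transpose scratch buffer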


@@ -514,7 +514,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
[ 7.14285714e-01, 1.20000000e-01, -1.04000000e+00, 3.08000000e+00],
[ 7.14285714e-01, -4.40000000e-01, -4.61538462e-01, 7.46153846e+00]])
>>> piv
Tensor(shape=[4], dtype=Int32, value= [2, 0, 3, 1])
Tensor(shape=[4], dtype=Int32, value= [2, 2, 3, 3])
"""
if F.dtype(a) not in float_types:
a = F.cast(a, mstype.float32)
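
The docstring change reflects the new pivot convention: piv is now a LAPACK-style ipiv array (row i was interchanged with row piv[i]) rather than a permutation vector. Replaying the swaps recovers the old value:

piv = [2, 2, 3, 3]                 # ipiv from the updated docstring
perm = list(range(len(piv)))
for i, p in enumerate(piv):
    perm[i], perm[p] = perm[p], perm[i]
print(perm)                        # [2, 0, 3, 1], the permutation the old docstring showed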
@@ -592,9 +592,11 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
m_lu, _, p = msp_lu(a)
m = a.shape[-2]
n = a.shape[-1]
if m > n:
_raise_value_error("last two dimensions of LU decomposition must be row less or equal to col.")
k = min(m, n)
a_dtype = a.dtype
l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype)
l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_dtype)
u = mnp.triu(m_lu)[:k, :]
if permute_l:
return mnp.dot(p, l), u
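
The slicing above in plain NumPy, for a single packed (m, n) factor with m <= n (the matrix values are hypothetical):

import numpy as np

m_lu = np.array([[4.0, 3.0, 1.0],              # packed LU: U on and above the diagonal,
                 [0.25, 2.25, 0.75]])          # L multipliers strictly below it
m, n = m_lu.shape
k = min(m, n)
l = np.tril(m_lu, -1)[:, :k] + np.eye(m, k)    # unit lower-triangular L, shape (m, k)
u = np.triu(m_lu)[:k, :]                       # upper-triangular U, shape (k, n)
assert np.allclose(l @ u, [[4.0, 3.0, 1.0], [1.0, 3.0, 1.0]])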
@@ -642,7 +644,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
check_lu_shape(m_lu, b)
# here permutation array has been calculated, just use it.
# 2. calculate permutation
permutation = lu_pivots_to_permutation(pivots, len(pivots))
permutation = lu_pivots_to_permutation(pivots, pivots.size)
# 3. rhs_vector
rhs_vector = m_lu.ndim == b.ndim + 1
x = lu_solve_core(m_lu, permutation, b, trans)
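
For reference, the structure of the solve that follows, with SciPy stand-ins for lu_solve_core (not shown in this hunk); under A = P @ L @ U with P[perm[i], i] = 1, applying the permutation is a row gather on b:

import numpy as np
from scipy.linalg import solve_triangular

def lu_solve_ref(l, u, perm, b):
    y = solve_triangular(l, b[perm], lower=True, unit_diagonal=True)  # L y = P^T b
    return solve_triangular(u, y)                                     # U x = y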


@@ -301,13 +301,16 @@ class LU(PrimitiveWithInfer):
x_shape = list(x['shape'])
x_dtype = x['dtype']
ndim = len(x_shape)
if ndim in (1, 2):
permutation_shape = (x_shape[0], x_shape[0])
k_shape = min(x_shape[0], x_shape[1])
permutation_shape = (k_shape, k_shape)
pivots_shape = (1, k_shape)
else:
permutation_shape = (x_shape[0], x_shape[1], x_shape[1])
k_shape = min(x_shape[1], x_shape[2])
permutation_shape = (x_shape[0], k_shape, k_shape)
pivots_shape = (x_shape[0], 1, k_shape)
output = {
'shape': (x_shape, permutation_shape[:-1], permutation_shape),
'shape': (x_shape, pivots_shape, permutation_shape),
'dtype': (x_dtype, mstype.int32, mstype.int32),
'value': None
}
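
A quick check of the new inference rule (values only; the names mirror the code above):

x_shape = [3, 4, 5]                                  # batched input
k_shape = min(x_shape[1], x_shape[2])                # k = 4
assert (x_shape[0], 1, k_shape) == (3, 1, 4)         # pivots_shape
assert (x_shape[0], k_shape, k_shape) == (3, 4, 4)   # permutation_shape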


@@ -217,9 +217,10 @@ def test_eigh_complex(n: int, dtype):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (10, 5), (20, 20)])
@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (5, 10), (20, 20)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_lu(shape: (int, int), dtype):
"""
@@ -238,6 +239,36 @@ def test_lu(shape: (int, int), dtype):
assert onp.allclose(m_u.asnumpy(), s_u, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5)])
@pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
def test_batch_lu(shape: (int, int, int), dtype):
"""
Feature: ALL To ALL
Description: test cases for batched lu decomposition
Expectation: the results match scipy
"""
b_a = create_random_rank_matrix(shape, dtype)
b_s_p = list()
b_s_l = list()
b_s_u = list()
for a in b_a:
s_p, s_l, s_u = osp.linalg.lu(a)
b_s_p.append(s_p)
b_s_l.append(s_l)
b_s_u.append(s_u)
tensor_b_a = Tensor(onp.array(b_a))
b_m_p, b_m_l, b_m_u = msp.linalg.lu(tensor_b_a)
rtol = 1.e-5
atol = 1.e-5
assert onp.allclose(b_m_p.asnumpy(), b_s_p, rtol=rtol, atol=atol)
assert onp.allclose(b_m_l.asnumpy(), b_s_l, rtol=rtol, atol=atol)
assert onp.allclose(b_m_u.asnumpy(), b_s_u, rtol=rtol, atol=atol)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard