From 15d35023e50eb95b7fb050528271013be1cae4a7 Mon Sep 17 00:00:00 2001 From: z00512249 Date: Thu, 13 Jan 2022 11:21:49 +0800 Subject: [PATCH] fix lu batched for gpu && cpu backend --- .../cpu/eigen/lu_cpu_kernel.cc | 196 +++++++++--------- .../kernel_compiler/cpu/eigen/lu_cpu_kernel.h | 5 +- .../kernel_compiler/gpu/math/lu_gpu_kernel.h | 182 +++++++++++----- mindspore/python/mindspore/scipy/linalg.py | 8 +- mindspore/python/mindspore/scipy/ops.py | 11 +- tests/st/scipy_st/test_linalg.py | 33 ++- 6 files changed, 275 insertions(+), 160 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc index 749790790dd..50192376dd7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc @@ -30,7 +30,6 @@ constexpr size_t kLUOutputsNum = 3; constexpr size_t kLuIndex = 0; constexpr size_t kPivotsIndex = 1; constexpr size_t kPermutationIndex = 2; -constexpr size_t kLUDefaultShape = 1; constexpr size_t kRowIndex = 2; constexpr size_t kColIndex = 1; constexpr int kZeroThreshold = INT32_MIN; @@ -38,16 +37,20 @@ constexpr int kZeroThreshold = INT32_MIN; template void LUCPUKernel::InitMatrixInfo(const std::vector &shape, size_t *row, size_t *col) { - if (shape.empty()) { - MS_LOG_EXCEPTION << kernel_name_ << "shape is invalid."; + constexpr size_t lu_min_dim = 1; + constexpr size_t lu_max_dim = 3; + if (shape.size() < lu_min_dim || shape.size() > lu_max_dim) { + MS_LOG_EXCEPTION << kernel_name_ << "shape is " << shape.size() << " which is invalid."; } - if (shape.size() == kLUDefaultShape) { - *row = shape.front(); - *col = 1; - } else { - *row = shape.at(shape.size() - kRowIndex); - *col = shape.at(shape.size() - kColIndex); + if (shape.size() == lu_max_dim) { + batch_ = shape.front(); + *row = shape.at(lu_min_dim); + *col = shape.at(lu_max_dim - 1); + return; } + batch_ = 1; + *row = shape.front(); + *col = shape.at(lu_min_dim); } template @@ -84,7 +87,8 @@ T LUCPUKernel::GetPermutatedValue(const T *lu_value, const std::vector & } template -bool LUCPUKernel::UpdateMajorPermutation(T *lu_value, std::vector *const per_value, size_t k, size_t rows) { +bool LUCPUKernel::UpdateMajorPermutation(T *lu_value, std::vector *const per_value, int *pivots, size_t k, + size_t rows) { T max_major_value = static_cast(kZeroThreshold); int max_major_index = 0; for (size_t i = k; i < rows; ++i) { @@ -98,7 +102,7 @@ bool LUCPUKernel::UpdateMajorPermutation(T *lu_value, std::vector *const int per_k = per_value->at(k); (*per_value)[k] = per_value->at(max_major_index); (*per_value)[max_major_index] = per_k; - pivots_[k] = max_major_index; + pivots[k] = max_major_index; return max_major_value != static_cast(kZeroThreshold); } @@ -114,102 +118,108 @@ bool LUCPUKernel::Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) { // input matrix of (m,n) PA = LU - T *a_value = reinterpret_cast(inputs[kLUaIndex]->addr); - T *lu_value = reinterpret_cast(outputs[kLuIndex]->addr); - // pivots permutation value - pivots_ = reinterpret_cast(outputs[kPivotsIndex]->addr); - // permutation matrix value - int *permutation_value = reinterpret_cast(outputs[kPermutationIndex]->addr); + T *batch_a_value = reinterpret_cast(inputs[kLUaIndex]->addr); + T *batch_lu_value = reinterpret_cast(outputs[kLuIndex]->addr); + batch_pivots_ = reinterpret_cast(outputs[kPivotsIndex]->addr); + int *batch_permutation_value = reinterpret_cast(outputs[kPermutationIndex]->addr); T *lu_ori_wk = reinterpret_cast(workspace[kLuIndex]->addr); T *lu_trans_wk = reinterpret_cast(workspace[kPivotsIndex]->addr); - // init pivots - std::vector per_value(pivots_row_, 0); - for (size_t i = 0; i < pivots_row_; ++i) { - per_value[i] = i; - } - - // 1. memcpy input to output, do full lu inplace. - (void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T)); - - int s = std::min(a_row_, a_col_); - // 2. do lu decompose inplace - for (int k = 0; k < s; ++k) { - // 2.1 choose major element of current col if return false means current col elements are all zero, just continue. - if (!UpdateMajorPermutation(lu_value, &per_value, k, lu_row_)) { - continue; - } - // 2.2 major element x --> (1/x), get inplace origin lu matrix value. - T value = static_cast(1.0 / GetPermutatedValue(lu_value, per_value, k, k)); - // 2.3 change major col values - for (size_t i = k + 1; i < lu_row_; ++i) { - T y = static_cast(GetPermutatedValue(lu_value, per_value, i, k) * value); - // set inplace new lu matrix value. - SetPermutatedValue(lu_value, per_value, i, k, y); + for (size_t batch = 0; batch < batch_; ++batch) { + T *a_value = batch_a_value + batch * a_row_ * a_col_; + T *lu_value = batch_lu_value + batch * lu_row_ * lu_col_; + // pivots permutation value + int *pivots = batch_pivots_ + batch * pivots_row_ * pivots_col_; + // permutation matrix value + int *permutation_value = batch_permutation_value + batch * permutation_row_ * permutation_col_; + // init pivots + std::vector per_value(pivots_col_, 0); + for (size_t i = 0; i < pivots_col_; ++i) { + per_value[i] = i; } + // 1. memcpy input to output, do full lu inplace. + (void)memcpy_s(lu_value, lu_row_ * lu_col_ * sizeof(T), a_value, a_row_ * a_col_ * sizeof(T)); - // 2.4 Gauss elimination core - for (size_t i = k + 1; i < lu_row_; ++i) { - for (size_t j = k + 1; j < lu_col_; ++j) { - T y = - static_cast(GetPermutatedValue(lu_value, per_value, i, j) - - GetPermutatedValue(lu_value, per_value, i, k) * GetPermutatedValue(lu_value, per_value, k, j)); - SetPermutatedValue(lu_value, per_value, i, j, y); + int s = std::min(a_row_, a_col_); + // 2. do lu decompose inplace + for (int k = 0; k < s; ++k) { + // 2.1 choose major element of current col if return false means current col elements are all zero, just continue. + if (!UpdateMajorPermutation(lu_value, &per_value, pivots, k, lu_row_)) { + continue; + } + // 2.2 major element x --> (1/x), get inplace origin lu matrix value. + T value = static_cast(1.0 / GetPermutatedValue(lu_value, per_value, k, k)); + // 2.3 change major col values + for (size_t i = k + 1; i < lu_row_; ++i) { + T y = static_cast(GetPermutatedValue(lu_value, per_value, i, k) * value); + // set inplace new lu matrix value. + SetPermutatedValue(lu_value, per_value, i, k, y); + } + + // 2.4 Gauss elimination core + for (size_t i = k + 1; i < lu_row_; ++i) { + for (size_t j = k + 1; j < lu_col_; ++j) { + T y = static_cast(GetPermutatedValue(lu_value, per_value, i, j) - + GetPermutatedValue(lu_value, per_value, i, k) * + GetPermutatedValue(lu_value, per_value, k, j)); + SetPermutatedValue(lu_value, per_value, i, j, y); + } } } - } - // 3. calculate final lu by permutation list - std::unordered_map> pivots_map; - for (int i = 0; i < static_cast(lu_row_); ++i) { - pivots_map[per_value[i]] = {i, false}; - } - int pivots_count = 0; - for (const auto &pivot : pivots_map) { - pivots_count++; - int key = pivot.first; - int index = pivot.second.first; - bool is_visited = pivot.second.second; - if (is_visited || index == (pivots_count - 1)) { - continue; + // 3. calculate final lu by permutation list + std::unordered_map> pivots_map; + for (int i = 0; i < static_cast(lu_row_); ++i) { + pivots_map[per_value[i]] = {i, false}; } + int pivots_count = 0; + for (const auto &pivot : pivots_map) { + pivots_count++; + int key = pivot.first; + int index = pivot.second.first; + bool is_visited = pivot.second.second; + if (is_visited || index == (pivots_count - 1)) { + continue; + } - T *lu_ori_row = lu_value + index * lu_col_; - T *lu_trans_row = lu_value + key * lu_col_; - // copy ori data to trans lu - (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T)); - // copy new data to ori data ptr - (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_trans_row, lu_col_ * sizeof(T)); - // update pivot map - pivots_map[key] = {index, true}; - // put ori data which stored in workspace to mapped new place - is_visited = pivots_map[index].second; - while (!is_visited) { - key = index; - index = pivots_map[key].first; - is_visited = pivots_map[key].second; - lu_ori_row = lu_value + index * lu_col_; - T *tmp_wk = lu_trans_wk; - lu_trans_wk = lu_ori_wk; - lu_ori_wk = tmp_wk; - // copy new ori data to trans workspace + T *lu_ori_row = lu_value + index * lu_col_; + T *lu_trans_row = lu_value + key * lu_col_; + // copy ori data to trans lu (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T)); - // copy new data to ori data place - (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_ori_wk, lu_col_ * sizeof(T)); + // copy new data to ori data ptr + (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_trans_row, lu_col_ * sizeof(T)); + // update pivot map pivots_map[key] = {index, true}; + // put ori data which stored in workspace to mapped new place + is_visited = pivots_map[index].second; + while (!is_visited) { + key = index; + index = pivots_map[key].first; + is_visited = pivots_map[key].second; + lu_ori_row = lu_value + index * lu_col_; + T *tmp_wk = lu_trans_wk; + lu_trans_wk = lu_ori_wk; + lu_ori_wk = tmp_wk; + // copy new ori data to trans workspace + (void)memcpy_s(lu_trans_wk, lu_col_ * sizeof(T), lu_ori_row, lu_col_ * sizeof(T)); + // copy new data to ori data place + (void)memcpy_s(lu_ori_row, lu_col_ * sizeof(T), lu_ori_wk, lu_col_ * sizeof(T)); + pivots_map[key] = {index, true}; + } + } + + // 4. calculate final permutation matrix + // for PA = LU get: base + row * permutation_row_ + col + // for A = PLU get: base + col * permutation_row_ + row + // here, we do A = PLU which is same as scipy. + size_t count = permutation_col_ * permutation_row_ * sizeof(int); + (void)memset_s(reinterpret_cast(permutation_value), count, 0, count); + for (size_t i = 0; i < pivots_col_; ++i) { + int position = per_value[i]; + int *per_addr = permutation_value + position * permutation_row_ + i; + *per_addr = 1; } } - // 4. calculate final permutation matrix - // for PA = LU get: base + row * permutation_row_ + col - // for A = PLU get: base + col * permutation_row_ + row - // here, we do A = PLU which is same as scipy. - size_t count = permutation_col_ * permutation_row_ * sizeof(int); - (void)memset_s(reinterpret_cast(permutation_value), count, 0, count); - for (size_t i = 0; i < pivots_row_; ++i) { - int position = per_value[i]; - int *per_addr = permutation_value + position * permutation_row_ + i; - *per_addr = 1; - } return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h index 441ce6d52f1..43ecdadd8f5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h @@ -36,8 +36,9 @@ class LUCPUKernel : public CPUKernel { void InitMatrixInfo(const std::vector &shape, size_t *row, size_t *col); void InitInputOutputSize(const CNodePtr &kernel_node) override; T GetPermutatedValue(const T *lu_value, const std::vector &per_value, size_t i, size_t j); - bool UpdateMajorPermutation(T *lu_value, std::vector *const per_value, size_t k, size_t rows); + bool UpdateMajorPermutation(T *lu_value, std::vector *const per_value, int *pivots, size_t k, size_t rows); void SetPermutatedValue(T *lu_value, const std::vector &per_value, size_t i, size_t j, const T &value); + size_t batch_{1}; size_t a_row_{1}; size_t a_col_{1}; size_t lu_row_{1}; @@ -47,7 +48,7 @@ class LUCPUKernel : public CPUKernel { size_t permutation_row_{1}; size_t permutation_col_{1}; TypeId dtype_{kNumberTypeFloat32}; - int *pivots_{nullptr}; + int *batch_pivots_{nullptr}; }; MS_REG_CPU_KERNEL_T(LU, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h index 04d2e8b7214..d48ce13e5a9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h @@ -26,16 +26,10 @@ #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/kernel_constants.h" #include "utils/convert_utils.h" -#include "backend/kernel_compiler/gpu/cuda_impl/triangle_matrix_copy_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" namespace mindspore { namespace kernel { -constexpr size_t kLuInputsNum = 1; -constexpr size_t kInputIndex = 0; -constexpr size_t kLuOutputsNum = 1; -constexpr size_t kOutputIndex = 0; -constexpr size_t kLuDefaultShape = 1; -constexpr size_t kLuNormalShape = 2; template class LUGpuKernel : public GpuKernel { @@ -53,54 +47,111 @@ class LUGpuKernel : public GpuKernel { } CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast(stream_ptr)), "cusolverDnSetStream failed"); - auto input_addr = GetDeviceAddress(inputs, kDim0); - auto output_addr = GetDeviceAddress(outputs, kDim0); - int *piv_output_addr = nullptr; + T *batch_input_addr = GetDeviceAddress(inputs, kDim0); + T *batch_output_addr = GetDeviceAddress(outputs, kDim0); + int *batch_piv_output_addr = nullptr; if (pivot_on_) { - piv_output_addr = GetDeviceAddress(outputs, kDim1); + batch_piv_output_addr = GetDeviceAddress(outputs, kDim1); } + int *batch_permutation_addr = GetDeviceAddress(outputs, kDim2); + int *info_output_addr = GetDeviceAddress(workspace, kDim0); - auto info_output_addr = GetDeviceAddress(outputs, kDim2); + size_t *dev_transpose_shape = GetDeviceAddress(workspace, kDim1); + size_t *dev_transpose_axis = GetDeviceAddress(workspace, kDim2); + constexpr size_t shape_2d = 2; + size_t host_transpose_shape[shape_2d] = {m_, n_}; + size_t host_transpose_axis[shape_2d] = {1, 0}; + T *dev_transpose_work = GetDeviceAddress(workspace, kDim3); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(dev_transpose_axis, host_transpose_axis, shape_2d * sizeof(size_t), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "malloc input shape workspace failed"); + + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_ * m_ * n_ * unit_size_, + cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in LUGpuKernel::Launch."); // 4. query working space of getrf if constexpr (std::is_same_v) { CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_, - cusolverDnSgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_), + cusolverDnSgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_), "cusolver query lu work size fail"); - if (cudaMalloc(reinterpret_cast(&d_work_), unit_size_ * lwork_) != cudaSuccess) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cusolver malloc work size fail"; - } - - CHECK_CUSOLVER_RET_WITH_EXCEPT( - kernel_node_, cusolverDnSgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr), - "cusolver lu fail"); - } else if constexpr (std::is_same_v) { CHECK_CUSOLVER_RET_WITH_EXCEPT(kernel_node_, - cusolverDnDgetrf_bufferSize(handle_, m_, m_, input_addr, lda_, &lwork_), + cusolverDnDgetrf_bufferSize(handle_, m_, n_, batch_output_addr, lda_, &lwork_), "cusolver query lu work size fail"); - // 5. malloc device working space of getrf - - if (cudaMalloc(reinterpret_cast(&d_work_), unit_size_ * lwork_) != cudaSuccess) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cusolver malloc work size fail"; - } - - // 6. solve to lu factorization according to cuSolver api, outputs have been written to input's matrix. - CHECK_CUSOLVER_RET_WITH_EXCEPT( - kernel_node_, cusolverDnDgetrf(handle_, m_, m_, input_addr, lda_, d_work_, piv_output_addr, info_output_addr), - "cusolver lu fail"); } else { MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now."; } - // 7. copy results from written input's matrix to output's matrix. - // if (cudaMemcpy(output_addr, input_addr, lda_ * m_ * unit_size_, cudaMemcpyDeviceToDevice) != cudaSuccess) { - // MS_LOG(EXCEPTION) << "memcpy lu output fail."; - // } - MatrixCopy(input_addr, output_addr, lda_ * m_, reinterpret_cast(stream_ptr)); - if (d_work_) { - cudaFree(d_work_); + // 5. malloc device working space of getrf + d_work_ = reinterpret_cast(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_)); + for (size_t batch = 0; batch < batch_; ++batch) { + T *output_addr = batch_output_addr + batch * m_ * n_; + int *permutation_addr = batch_permutation_addr + batch * k_ * k_; + int *piv_output_addr = batch_piv_output_addr + batch * k_; + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(dev_transpose_shape, host_transpose_shape, shape_2d * sizeof(size_t), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "malloc input shape workspace failed"); + + CalTranspose(m_ * n_, output_addr, dev_transpose_shape, dev_transpose_axis, shape_2d, dev_transpose_work, + reinterpret_cast(stream_ptr)); + + // 6.lu factorization according to cuSolver api, outputs have been written to input's matrix. + if constexpr (std::is_same_v) { + CHECK_CUSOLVER_RET_WITH_EXCEPT( + kernel_node_, + cusolverDnSgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr), + "cusolver lu fail"); + } else if constexpr (std::is_same_v) { + // 6.lu factorization according to cuSolver api, outputs have been written to input's matrix. + CHECK_CUSOLVER_RET_WITH_EXCEPT( + kernel_node_, + cusolverDnDgetrf(handle_, m_, n_, dev_transpose_work, lda_, d_work_, piv_output_addr, info_output_addr), + "cusolver lu fail"); + } else { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now."; + } + + size_t host_wk_transpose_shape[shape_2d] = {n_, m_}; + cudaMemcpyAsync(dev_transpose_shape, host_wk_transpose_shape, shape_2d * sizeof(size_t), cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)); + CalTranspose(m_ * n_, dev_transpose_work, dev_transpose_shape, dev_transpose_axis, shape_2d, output_addr, + reinterpret_cast(stream_ptr)); + std::vector host_permuted(k_, 0); + std::vector host_pivots(k_, 0); + std::vector host_permutation(k_ * k_, 0); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(host_pivots.data(), piv_output_addr, sizeof(int) * k_, + cudaMemcpyDeviceToHost, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in LUGpuKernel::Launch copy pivots to host."); + + // cal pivots && permutation major by row. + for (size_t i = 0; i < k_; ++i) { + host_pivots[i] -= 1; + host_permuted[i] = i; + } + for (size_t i = 0; i < k_; ++i) { + int tmp_value = host_permuted[i]; + host_permuted[i] = host_permuted[host_pivots[i]]; + host_permuted[host_pivots[i]] = tmp_value; + } + // gpu default is P.A = LU, so here is col swap. + for (size_t i = 0; i < k_; ++i) { + host_permutation[host_permuted[i] * k_ + i] = 1; + } + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(permutation_addr, host_permutation.data(), sizeof(int) * k_ * k_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in LUGpuKernel::Launch copy permutation matrix."); + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(piv_output_addr, host_pivots.data(), sizeof(int) * k_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in LUGpuKernel::Launch copy pivots array."); } + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(d_work_); return true; } @@ -125,48 +176,65 @@ class LUGpuKernel : public GpuKernel { private: bool InitInputSize(const std::vector &in_shape) { - if (in_shape.size() == kLuDefaultShape) { - lu_row_ = in_shape.at(kDim0); - lu_col_ = lu_row_; - } else if (in_shape.size() == kLuNormalShape) { - lu_row_ = in_shape.at(kDim0); - lu_col_ = in_shape.at(kDim1); - } else { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of input only should be 1 or 2"; - return false; + constexpr size_t lu_min_dim = 1; + constexpr size_t lu_max_dim = 3; + if (in_shape.size() < lu_min_dim || in_shape.size() > lu_max_dim) { + MS_LOG_EXCEPTION << kernel_name_ << "shape is " << in_shape.size() << " which is invalid."; } - if (lu_row_ != lu_col_) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of input should be square matrix"; - return false; + if (in_shape.size() == lu_max_dim) { + batch_ = in_shape.front(); + lu_row_ = in_shape.at(lu_min_dim); + lu_col_ = in_shape.at(lu_max_dim - 1); + } else { + batch_ = 1; + lu_row_ = in_shape.front(); + lu_col_ = in_shape.at(lu_min_dim); } // set matrix row or col to be lead dimension m_ = SizeToInt(lu_row_); + n_ = SizeToInt(lu_col_); + k_ = std::min(lu_row_, lu_col_); lda_ = m_; - ldb_ = m_; + ldb_ = n_; InitSizeLists(); return true; } void InitSizeLists() override { - size_t input_size = lda_ * m_ * unit_size_; + size_t input_size = batch_ * lu_row_ * lu_col_ * unit_size_; input_size_list_.push_back(input_size); - size_t output_size = lda_ * m_ * unit_size_; + size_t output_size = batch_ * lu_row_ * lu_col_ * unit_size_; + size_t output_piv_size = 0; - size_t output_info_size = sizeof(int); if (pivot_on_) { - output_piv_size = m_ * sizeof(int); + output_piv_size = batch_ * k_ * sizeof(int); } + size_t output_permutation_size = batch_ * k_ * k_ * sizeof(int); output_size_list_.resize(kDim3); output_size_list_[kDim0] = output_size; output_size_list_[kDim1] = output_piv_size; - output_size_list_[kDim2] = output_info_size; + output_size_list_[kDim2] = output_permutation_size; + + // a device addr to place lu factor return code + workspace_size_list_.push_back(sizeof(int)); + + // transpose 2d matrix scalar args workspace + constexpr size_t shape_2d = 2; + workspace_size_list_.push_back(shape_2d * sizeof(size_t)); + workspace_size_list_.push_back(shape_2d * sizeof(size_t)); + + // transpose workspace + workspace_size_list_.push_back(m_ * n_ * unit_size_); } size_t unit_size_{sizeof(T)}; + size_t batch_{1}; size_t lu_row_{0}; size_t lu_col_{0}; + size_t k_{0}; size_t m_{0}; + size_t n_{0}; size_t lda_{0}; size_t ldb_{0}; int lwork_{0}; diff --git a/mindspore/python/mindspore/scipy/linalg.py b/mindspore/python/mindspore/scipy/linalg.py index 56808d40bd0..a92689b8cd5 100755 --- a/mindspore/python/mindspore/scipy/linalg.py +++ b/mindspore/python/mindspore/scipy/linalg.py @@ -514,7 +514,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True): [ 7.14285714e-01, 1.20000000e-01, -1.04000000e+00, 3.08000000e+00], [ 7.14285714e-01, -4.40000000e-01, -4.61538462e-01, 7.46153846e+00]]) >>> piv - Tensor(shape=[4], dtype=Int32, value= [2, 0, 3, 1]) + Tensor(shape=[4], dtype=Int32, value= [2, 2, 3, 3]) """ if F.dtype(a) not in float_types: a = F.cast(a, mstype.float32) @@ -592,9 +592,11 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True): m_lu, _, p = msp_lu(a) m = a.shape[-2] n = a.shape[-1] + if m > n: + _raise_value_error("last two dimensions of LU decomposition must be row less or equal to col.") k = min(m, n) a_dtype = a.dtype - l = mnp.tril(m_lu, -1)[:, :k] + mnp.eye(m, k, dtype=a_dtype) + l = mnp.tril(m_lu, -1)[..., :k] + mnp.eye(m, k, dtype=a_dtype) u = mnp.triu(m_lu)[:k, :] if permute_l: return mnp.dot(p, l), u @@ -642,7 +644,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): check_lu_shape(m_lu, b) # here permutation array has been calculated, just use it. # 2. calculate permutation - permutation = lu_pivots_to_permutation(pivots, len(pivots)) + permutation = lu_pivots_to_permutation(pivots, pivots.size) # 3. rhs_vector rhs_vector = m_lu.ndim == b.ndim + 1 x = lu_solve_core(m_lu, permutation, b, trans) diff --git a/mindspore/python/mindspore/scipy/ops.py b/mindspore/python/mindspore/scipy/ops.py index d2316a3b9e8..628645b9212 100644 --- a/mindspore/python/mindspore/scipy/ops.py +++ b/mindspore/python/mindspore/scipy/ops.py @@ -301,13 +301,16 @@ class LU(PrimitiveWithInfer): x_shape = list(x['shape']) x_dtype = x['dtype'] ndim = len(x_shape) - if ndim in (1, 2): - permutation_shape = (x_shape[0], x_shape[0]) + k_shape = min(x_shape[0], x_shape[1]) + permutation_shape = (k_shape, k_shape) + pivots_shape = (1, k_shape) else: - permutation_shape = (x_shape[0], x_shape[1], x_shape[1]) + k_shape = min(x_shape[1], x_shape[2]) + permutation_shape = (x_shape[0], k_shape, k_shape) + pivots_shape = (x_shape[0], 1, k_shape) output = { - 'shape': (x_shape, permutation_shape[:-1], permutation_shape), + 'shape': (x_shape, pivots_shape, permutation_shape), 'dtype': (x_dtype, mstype.int32, mstype.int32), 'value': None } diff --git a/tests/st/scipy_st/test_linalg.py b/tests/st/scipy_st/test_linalg.py index 60f984633f7..27e30cc97c1 100644 --- a/tests/st/scipy_st/test_linalg.py +++ b/tests/st/scipy_st/test_linalg.py @@ -217,9 +217,10 @@ def test_eigh_complex(n: int, dtype): @pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (10, 5), (20, 20)]) +@pytest.mark.parametrize('shape', [(4, 4), (4, 5), (5, 10), (20, 20)]) @pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) def test_lu(shape: (int, int), dtype): """ @@ -238,6 +239,36 @@ def test_lu(shape: (int, int), dtype): assert onp.allclose(m_u.asnumpy(), s_u, rtol=rtol, atol=atol) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5)]) +@pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) +def test_batch_lu(shape: (int, int, int), dtype): + """ + Feature: ALL To ALL + Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] + Expectation: the result match to scipy + """ + b_a = create_random_rank_matrix(shape, dtype) + b_s_p = list() + b_s_l = list() + b_s_u = list() + for a in b_a: + s_p, s_l, s_u = osp.linalg.lu(a) + b_s_p.append(s_p) + b_s_l.append(s_l) + b_s_u.append(s_u) + tensor_b_a = Tensor(onp.array(b_a)) + b_m_p, b_m_l, b_m_u = msp.linalg.lu(tensor_b_a) + rtol = 1.e-5 + atol = 1.e-5 + assert onp.allclose(b_m_p.asnumpy(), b_s_p, rtol=rtol, atol=atol) + assert onp.allclose(b_m_l.asnumpy(), b_s_l, rtol=rtol, atol=atol) + assert onp.allclose(b_m_u.asnumpy(), b_s_u, rtol=rtol, atol=atol) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard