From da492e82d9e1bce80774b9509d718b6fced79513 Mon Sep 17 00:00:00 2001 From: z00512249 Date: Wed, 12 Jan 2022 10:49:43 +0800 Subject: [PATCH] motify lu pivots's shape --- .../cpu/eigen/lu_cpu_kernel.cc | 37 ++++++++++++------- .../kernel_compiler/cpu/eigen/lu_cpu_kernel.h | 3 +- .../kernel_compiler/gpu/math/lu_gpu_kernel.h | 33 ++++++++--------- mindspore/python/mindspore/scipy/ops.py | 12 ++---- tests/st/scipy_st/test_linalg.py | 11 ++++-- 5 files changed, 52 insertions(+), 44 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc index 50192376dd7..0ac914b6a55 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc @@ -38,19 +38,30 @@ constexpr int kZeroThreshold = INT32_MIN; template void LUCPUKernel::InitMatrixInfo(const std::vector &shape, size_t *row, size_t *col) { constexpr size_t lu_min_dim = 1; - constexpr size_t lu_max_dim = 3; - if (shape.size() < lu_min_dim || shape.size() > lu_max_dim) { + if (shape.size() <= lu_min_dim) { MS_LOG_EXCEPTION << kernel_name_ << "shape is " << shape.size() << " which is invalid."; } - if (shape.size() == lu_max_dim) { - batch_ = shape.front(); - *row = shape.at(lu_min_dim); - *col = shape.at(lu_max_dim - 1); - return; + constexpr size_t lu_reverse_row_dim = 2; + *row = shape.at(shape.size() - lu_reverse_row_dim); + *col = shape.at(shape.size() - 1); + batch_size_ = lu_min_dim; + for (int batch = 0; batch < static_cast(shape.size() - lu_reverse_row_dim); ++batch) { + batch_size_ *= shape.at(batch); + } +} + +template +void LUCPUKernel::InitPivotVecInfo(const std::vector &shape, size_t *row, size_t *col) { + constexpr size_t pivot_min_dim = 1; + if (shape.size() < pivot_min_dim) { + MS_LOG_EXCEPTION << kernel_name_ << "pivots shape is " << shape.size() << " which is invalid."; + } + *row = 1; + if (shape.size() == pivot_min_dim) { + *col = shape.front(); + } else { + *col = shape.back(); } - batch_ = 1; - *row = shape.front(); - *col = shape.at(lu_min_dim); } template @@ -66,10 +77,10 @@ void LUCPUKernel::InitKernel(const CNodePtr &kernel_node) { InitMatrixInfo(a_shape, &a_row_, &a_col_); auto lu_shape = AnfAlgo::GetOutputInferShape(kernel_node, kLuIndex); InitMatrixInfo(lu_shape, &lu_row_, &lu_col_); - auto pivots_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex); - InitMatrixInfo(pivots_shape, &pivots_row_, &pivots_col_); auto permutation_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPermutationIndex); InitMatrixInfo(permutation_shape, &permutation_row_, &permutation_col_); + auto pivots_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex); + InitPivotVecInfo(pivots_shape, &pivots_row_, &pivots_col_); } template @@ -124,7 +135,7 @@ bool LUCPUKernel::Launch(const std::vector &inputs, int *batch_permutation_value = reinterpret_cast(outputs[kPermutationIndex]->addr); T *lu_ori_wk = reinterpret_cast(workspace[kLuIndex]->addr); T *lu_trans_wk = reinterpret_cast(workspace[kPivotsIndex]->addr); - for (size_t batch = 0; batch < batch_; ++batch) { + for (size_t batch = 0; batch < batch_size_; ++batch) { T *a_value = batch_a_value + batch * a_row_ * a_col_; T *lu_value = batch_lu_value + batch * lu_row_ * lu_col_; // pivots permutation value diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h index 43ecdadd8f5..2b8bfa974d2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h @@ -34,11 +34,12 @@ class LUCPUKernel : public CPUKernel { private: void InitMatrixInfo(const std::vector &shape, size_t *row, size_t *col); + void InitPivotVecInfo(const std::vector &shape, size_t *row, size_t *col); void InitInputOutputSize(const CNodePtr &kernel_node) override; T GetPermutatedValue(const T *lu_value, const std::vector &per_value, size_t i, size_t j); bool UpdateMajorPermutation(T *lu_value, std::vector *const per_value, int *pivots, size_t k, size_t rows); void SetPermutatedValue(T *lu_value, const std::vector &per_value, size_t i, size_t j, const T &value); - size_t batch_{1}; + size_t batch_size_{1}; size_t a_row_{1}; size_t a_col_{1}; size_t lu_row_{1}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h index d48ce13e5a9..c53fc8d9f45 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/lu_gpu_kernel.h @@ -68,7 +68,7 @@ class LUGpuKernel : public GpuKernel { "malloc input shape workspace failed"); CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, - cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_ * m_ * n_ * unit_size_, + cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), "cudaMemcpyAsync failed in LUGpuKernel::Launch."); @@ -87,7 +87,7 @@ class LUGpuKernel : public GpuKernel { } // 5. malloc device working space of getrf d_work_ = reinterpret_cast(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_)); - for (size_t batch = 0; batch < batch_; ++batch) { + for (size_t batch = 0; batch < batch_size_; ++batch) { T *output_addr = batch_output_addr + batch * m_ * n_; int *permutation_addr = batch_permutation_addr + batch * k_ * k_; int *piv_output_addr = batch_piv_output_addr + batch * k_; @@ -177,18 +177,15 @@ class LUGpuKernel : public GpuKernel { private: bool InitInputSize(const std::vector &in_shape) { constexpr size_t lu_min_dim = 1; - constexpr size_t lu_max_dim = 3; - if (in_shape.size() < lu_min_dim || in_shape.size() > lu_max_dim) { - MS_LOG_EXCEPTION << kernel_name_ << "shape is " << in_shape.size() << " which is invalid."; + if (in_shape.size() <= lu_min_dim) { + MS_LOG_EXCEPTION << kernel_name_ << " input shape is " << in_shape.size() << " which is invalid."; } - if (in_shape.size() == lu_max_dim) { - batch_ = in_shape.front(); - lu_row_ = in_shape.at(lu_min_dim); - lu_col_ = in_shape.at(lu_max_dim - 1); - } else { - batch_ = 1; - lu_row_ = in_shape.front(); - lu_col_ = in_shape.at(lu_min_dim); + constexpr size_t lu_reverse_row_dim = 2; + lu_row_ = in_shape.at(in_shape.size() - lu_reverse_row_dim); + lu_col_ = in_shape.at(in_shape.size() - 1); + batch_size_ = lu_min_dim; + for (int batch = 0; batch < static_cast(in_shape.size() - lu_reverse_row_dim); ++batch) { + batch_size_ *= in_shape.at(batch); } // set matrix row or col to be lead dimension m_ = SizeToInt(lu_row_); @@ -201,16 +198,16 @@ class LUGpuKernel : public GpuKernel { } void InitSizeLists() override { - size_t input_size = batch_ * lu_row_ * lu_col_ * unit_size_; + size_t input_size = batch_size_ * lu_row_ * lu_col_ * unit_size_; input_size_list_.push_back(input_size); - size_t output_size = batch_ * lu_row_ * lu_col_ * unit_size_; + size_t output_size = batch_size_ * lu_row_ * lu_col_ * unit_size_; size_t output_piv_size = 0; if (pivot_on_) { - output_piv_size = batch_ * k_ * sizeof(int); + output_piv_size = batch_size_ * k_ * sizeof(int); } - size_t output_permutation_size = batch_ * k_ * k_ * sizeof(int); + size_t output_permutation_size = batch_size_ * k_ * k_ * sizeof(int); output_size_list_.resize(kDim3); output_size_list_[kDim0] = output_size; output_size_list_[kDim1] = output_piv_size; @@ -229,7 +226,7 @@ class LUGpuKernel : public GpuKernel { } size_t unit_size_{sizeof(T)}; - size_t batch_{1}; + size_t batch_size_{1}; size_t lu_row_{0}; size_t lu_col_{0}; size_t k_{0}; diff --git a/mindspore/python/mindspore/scipy/ops.py b/mindspore/python/mindspore/scipy/ops.py index 628645b9212..ef35972d03e 100644 --- a/mindspore/python/mindspore/scipy/ops.py +++ b/mindspore/python/mindspore/scipy/ops.py @@ -300,15 +300,9 @@ class LU(PrimitiveWithInfer): def __infer__(self, x): x_shape = list(x['shape']) x_dtype = x['dtype'] - ndim = len(x_shape) - if ndim in (1, 2): - k_shape = min(x_shape[0], x_shape[1]) - permutation_shape = (k_shape, k_shape) - pivots_shape = (1, k_shape) - else: - k_shape = min(x_shape[1], x_shape[2]) - permutation_shape = (x_shape[0], k_shape, k_shape) - pivots_shape = (x_shape[0], 1, k_shape) + k_shape = min(x_shape[-1], x_shape[-2]) + permutation_shape = x_shape[:-2] + [k_shape, k_shape] + pivots_shape = x_shape[:-2] + [k_shape] output = { 'shape': (x_shape, pivots_shape, permutation_shape), 'dtype': (x_dtype, mstype.int32, mstype.int32), diff --git a/tests/st/scipy_st/test_linalg.py b/tests/st/scipy_st/test_linalg.py index 27e30cc97c1..4b8c644d478 100644 --- a/tests/st/scipy_st/test_linalg.py +++ b/tests/st/scipy_st/test_linalg.py @@ -243,9 +243,9 @@ def test_lu(shape: (int, int), dtype): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5)]) +@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5), (2, 3, 4, 5)]) @pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) -def test_batch_lu(shape: (int, int, int), dtype): +def test_batch_lu(shape, dtype): """ Feature: ALL To ALL Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] @@ -255,13 +255,18 @@ def test_batch_lu(shape: (int, int, int), dtype): b_s_p = list() b_s_l = list() b_s_u = list() - for a in b_a: + tmp = onp.zeros(b_a.shape[:-2]) + for index, _ in onp.ndenumerate(tmp): + a = b_a[index] s_p, s_l, s_u = osp.linalg.lu(a) b_s_p.append(s_p) b_s_l.append(s_l) b_s_u.append(s_u) tensor_b_a = Tensor(onp.array(b_a)) b_m_p, b_m_l, b_m_u = msp.linalg.lu(tensor_b_a) + b_s_p = onp.asarray(b_s_p).reshape(b_m_p.shape) + b_s_l = onp.asarray(b_s_l).reshape(b_m_l.shape) + b_s_u = onp.asarray(b_s_u).reshape(b_m_u.shape) rtol = 1.e-5 atol = 1.e-5 assert onp.allclose(b_m_p.asnumpy(), b_s_p, rtol=rtol, atol=atol)