!28923 modify lu pivots' shape

Merge pull request !28923 from zhuzhongrui/pub_master
i-robot 2022-01-13 09:14:59 +00:00 committed by Gitee
commit 05c18009c3
5 changed files with 52 additions and 44 deletions


@@ -38,19 +38,30 @@ constexpr int kZeroThreshold = INT32_MIN;
 template <typename T>
 void LUCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
   constexpr size_t lu_min_dim = 1;
-  constexpr size_t lu_max_dim = 3;
-  if (shape.size() < lu_min_dim || shape.size() > lu_max_dim) {
+  if (shape.size() <= lu_min_dim) {
     MS_LOG_EXCEPTION << kernel_name_ << "shape is " << shape.size() << " which is invalid.";
   }
-  if (shape.size() == lu_max_dim) {
-    batch_ = shape.front();
-    *row = shape.at(lu_min_dim);
-    *col = shape.at(lu_max_dim - 1);
-    return;
+  constexpr size_t lu_reverse_row_dim = 2;
+  *row = shape.at(shape.size() - lu_reverse_row_dim);
+  *col = shape.at(shape.size() - 1);
+  batch_size_ = lu_min_dim;
+  for (int batch = 0; batch < static_cast<int>(shape.size() - lu_reverse_row_dim); ++batch) {
+    batch_size_ *= shape.at(batch);
   }
+}
+template <typename T>
+void LUCPUKernel<T>::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
+  constexpr size_t pivot_min_dim = 1;
+  if (shape.size() < pivot_min_dim) {
+    MS_LOG_EXCEPTION << kernel_name_ << "pivots shape is " << shape.size() << " which is invalid.";
+  }
+  *row = 1;
+  if (shape.size() == pivot_min_dim) {
+    *col = shape.front();
+  } else {
+    *col = shape.back();
+  }
-  batch_ = 1;
-  *row = shape.front();
-  *col = shape.at(lu_min_dim);
 }
 template <typename T>
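In effect, the rewritten InitMatrixInfo stops special-casing 3-D inputs: the last two dimensions are always the matrix (row, col), and every leading dimension is folded into batch_size_. A minimal Python sketch of that shape logic (the function and variable names are illustrative, not part of the patch):

    from functools import reduce

    def init_matrix_info(shape):
        # Rank must exceed 1: a lone vector cannot be LU-factored.
        if len(shape) <= 1:
            raise ValueError(f"LU input rank {len(shape)} is invalid")
        row, col = shape[-2], shape[-1]
        # All leading dimensions multiply into one flat batch count.
        batch_size = reduce(lambda a, b: a * b, shape[:-2], 1)
        return batch_size, row, col

    assert init_matrix_info((4, 5)) == (1, 4, 5)        # single matrix
    assert init_matrix_info((2, 3, 4, 5)) == (6, 4, 5)  # 2*3 = 6 batches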
@@ -66,10 +77,10 @@ void LUCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   InitMatrixInfo(a_shape, &a_row_, &a_col_);
   auto lu_shape = AnfAlgo::GetOutputInferShape(kernel_node, kLuIndex);
   InitMatrixInfo(lu_shape, &lu_row_, &lu_col_);
-  auto pivots_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex);
-  InitMatrixInfo(pivots_shape, &pivots_row_, &pivots_col_);
   auto permutation_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPermutationIndex);
   InitMatrixInfo(permutation_shape, &permutation_row_, &permutation_col_);
+  auto pivots_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex);
+  InitPivotVecInfo(pivots_shape, &pivots_row_, &pivots_col_);
 }
 template <typename T>
@@ -124,7 +135,7 @@ bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
   int *batch_permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
   T *lu_ori_wk = reinterpret_cast<T *>(workspace[kLuIndex]->addr);
   T *lu_trans_wk = reinterpret_cast<T *>(workspace[kPivotsIndex]->addr);
-  for (size_t batch = 0; batch < batch_; ++batch) {
+  for (size_t batch = 0; batch < batch_size_; ++batch) {
     T *a_value = batch_a_value + batch * a_row_ * a_col_;
     T *lu_value = batch_lu_value + batch * lu_row_ * lu_col_;
     // pivots permutation value
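Because each batch matrix occupies a contiguous row * col slab, the per-batch pointers above are plain offsets into one flat buffer. The same striding, sketched with NumPy (names and sizes are illustrative):

    import numpy as np

    batch_size, lu_row, lu_col = 6, 4, 5
    flat = np.zeros(batch_size * lu_row * lu_col, dtype=np.float32)

    # Equivalent of `batch_lu_value + batch * lu_row_ * lu_col_` in the C++ loop.
    for batch in range(batch_size):
        start = batch * lu_row * lu_col
        lu_value = flat[start:start + lu_row * lu_col].reshape(lu_row, lu_col)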


@@ -34,11 +34,12 @@ class LUCPUKernel : public CPUKernel {
  private:
   void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
+  void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
   void InitInputOutputSize(const CNodePtr &kernel_node) override;
   T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
   bool UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k, size_t rows);
   void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
-  size_t batch_{1};
+  size_t batch_size_{1};
   size_t a_row_{1};
   size_t a_col_{1};
   size_t lu_row_{1};


@@ -68,7 +68,7 @@ class LUGpuKernel : public GpuKernel {
                                "malloc input shape workspace failed");
     CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_ * m_ * n_ * unit_size_,
+                               cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_,
                                                cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cudaMemcpyAsync failed in LUGpuKernel::Launch.");
@@ -87,7 +87,7 @@ class LUGpuKernel : public GpuKernel {
     }
     // 5. malloc device working space of getrf
     d_work_ = reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_));
-    for (size_t batch = 0; batch < batch_; ++batch) {
+    for (size_t batch = 0; batch < batch_size_; ++batch) {
       T *output_addr = batch_output_addr + batch * m_ * n_;
       int *permutation_addr = batch_permutation_addr + batch * k_ * k_;
       int *piv_output_addr = batch_piv_output_addr + batch * k_;
@@ -177,18 +177,15 @@ class LUGpuKernel : public GpuKernel {
  private:
   bool InitInputSize(const std::vector<size_t> &in_shape) {
     constexpr size_t lu_min_dim = 1;
-    constexpr size_t lu_max_dim = 3;
-    if (in_shape.size() < lu_min_dim || in_shape.size() > lu_max_dim) {
-      MS_LOG_EXCEPTION << kernel_name_ << "shape is " << in_shape.size() << " which is invalid.";
+    if (in_shape.size() <= lu_min_dim) {
+      MS_LOG_EXCEPTION << kernel_name_ << " input shape is " << in_shape.size() << " which is invalid.";
     }
-    if (in_shape.size() == lu_max_dim) {
-      batch_ = in_shape.front();
-      lu_row_ = in_shape.at(lu_min_dim);
-      lu_col_ = in_shape.at(lu_max_dim - 1);
-    } else {
-      batch_ = 1;
-      lu_row_ = in_shape.front();
-      lu_col_ = in_shape.at(lu_min_dim);
+    constexpr size_t lu_reverse_row_dim = 2;
+    lu_row_ = in_shape.at(in_shape.size() - lu_reverse_row_dim);
+    lu_col_ = in_shape.at(in_shape.size() - 1);
+    batch_size_ = lu_min_dim;
+    for (int batch = 0; batch < static_cast<int>(in_shape.size() - lu_reverse_row_dim); ++batch) {
+      batch_size_ *= in_shape.at(batch);
     }
     // set matrix row or col to be lead dimension
     m_ = SizeToInt(lu_row_);
@@ -201,16 +198,16 @@ class LUGpuKernel : public GpuKernel {
   }
   void InitSizeLists() override {
-    size_t input_size = batch_ * lu_row_ * lu_col_ * unit_size_;
+    size_t input_size = batch_size_ * lu_row_ * lu_col_ * unit_size_;
     input_size_list_.push_back(input_size);
-    size_t output_size = batch_ * lu_row_ * lu_col_ * unit_size_;
+    size_t output_size = batch_size_ * lu_row_ * lu_col_ * unit_size_;
     size_t output_piv_size = 0;
     if (pivot_on_) {
-      output_piv_size = batch_ * k_ * sizeof(int);
+      output_piv_size = batch_size_ * k_ * sizeof(int);
     }
-    size_t output_permutation_size = batch_ * k_ * k_ * sizeof(int);
+    size_t output_permutation_size = batch_size_ * k_ * k_ * sizeof(int);
     output_size_list_.resize(kDim3);
     output_size_list_[kDim0] = output_size;
     output_size_list_[kDim1] = output_piv_size;
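Every buffer size now scales with the flattened batch count rather than a single leading dimension. A rough Python check of the arithmetic, assuming float32 elements and k_ = min(m, n) as in LAPACK getrf (both are assumptions, not taken from the patch):

    unit_size = 4                      # assumed: sizeof(float32)
    batch_size, lu_row, lu_col = 6, 4, 5
    k = min(lu_row, lu_col)            # assumed: getrf pivot length

    input_size = batch_size * lu_row * lu_col * unit_size  # 480 bytes
    output_piv_size = batch_size * k * 4                   # int pivots
    output_permutation_size = batch_size * k * k * 4       # int matrices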
@@ -229,7 +226,7 @@ class LUGpuKernel : public GpuKernel {
   }
   size_t unit_size_{sizeof(T)};
-  size_t batch_{1};
+  size_t batch_size_{1};
   size_t lu_row_{0};
   size_t lu_col_{0};
   size_t k_{0};


@@ -300,15 +300,9 @@ class LU(PrimitiveWithInfer):
     def __infer__(self, x):
         x_shape = list(x['shape'])
         x_dtype = x['dtype']
-        ndim = len(x_shape)
-        if ndim in (1, 2):
-            k_shape = min(x_shape[0], x_shape[1])
-            permutation_shape = (k_shape, k_shape)
-            pivots_shape = (1, k_shape)
-        else:
-            k_shape = min(x_shape[1], x_shape[2])
-            permutation_shape = (x_shape[0], k_shape, k_shape)
-            pivots_shape = (x_shape[0], 1, k_shape)
+        k_shape = min(x_shape[-1], x_shape[-2])
+        permutation_shape = x_shape[:-2] + [k_shape, k_shape]
+        pivots_shape = x_shape[:-2] + [k_shape]
         output = {
             'shape': (x_shape, pivots_shape, permutation_shape),
             'dtype': (x_dtype, mstype.int32, mstype.int32),
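The new __infer__ builds the pivot and permutation shapes from however many batch dimensions precede the matrix. For a 4-D input the outputs come out as follows (values follow directly from the code above):

    x_shape = [2, 3, 4, 5]
    k_shape = min(x_shape[-1], x_shape[-2])                # 4
    permutation_shape = x_shape[:-2] + [k_shape, k_shape]  # [2, 3, 4, 4]
    pivots_shape = x_shape[:-2] + [k_shape]                # [2, 3, 4]

Note that pivots lose the old singleton dimension: a batched input previously produced pivots of shape (batch, 1, k), now simply the batch dims plus [k], which is the shape change the commit title refers to.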


@@ -243,9 +243,9 @@ def test_lu(shape: (int, int), dtype):
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5)])
+@pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5), (2, 3, 4, 5)])
 @pytest.mark.parametrize('dtype', [onp.float32, onp.float64])
-def test_batch_lu(shape: (int, int, int), dtype):
+def test_batch_lu(shape, dtype):
     """
     Feature: ALL To ALL
     Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
@@ -255,13 +255,18 @@ def test_batch_lu(shape: (int, int, int), dtype):
     b_s_p = list()
     b_s_l = list()
     b_s_u = list()
-    for a in b_a:
+    tmp = onp.zeros(b_a.shape[:-2])
+    for index, _ in onp.ndenumerate(tmp):
+        a = b_a[index]
         s_p, s_l, s_u = osp.linalg.lu(a)
         b_s_p.append(s_p)
         b_s_l.append(s_l)
         b_s_u.append(s_u)
     tensor_b_a = Tensor(onp.array(b_a))
     b_m_p, b_m_l, b_m_u = msp.linalg.lu(tensor_b_a)
+    b_s_p = onp.asarray(b_s_p).reshape(b_m_p.shape)
+    b_s_l = onp.asarray(b_s_l).reshape(b_m_l.shape)
+    b_s_u = onp.asarray(b_s_u).reshape(b_m_u.shape)
     rtol = 1.e-5
     atol = 1.e-5
     assert onp.allclose(b_m_p.asnumpy(), b_s_p, rtol=rtol, atol=atol)
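The onp.ndenumerate pattern generalizes the old `for a in b_a:` loop, which only walked the first dimension, to any number of leading batch dimensions. A standalone sketch of the same iteration with NumPy and SciPy (array shapes are illustrative):

    import numpy as onp
    from scipy import linalg as sla

    b_a = onp.random.random((2, 3, 4, 5)).astype(onp.float32)

    # Yields one (index, _) pair per batch coordinate: (0, 0), (0, 1), ...
    tmp = onp.zeros(b_a.shape[:-2])
    for index, _ in onp.ndenumerate(tmp):
        s_p, s_l, s_u = sla.lu(b_a[index])
        assert s_p.shape == (4, 4) and s_u.shape == (4, 5)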