!29625 Add GPU matrix_diag_part ops, optimize related testcase

Merge pull request !29625 from wuwenbing/dev
This commit is contained in:
i-robot 2022-01-29 06:40:11 +00:00 committed by Gitee
commit 776d937266
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 498 additions and 161 deletions

View File

@ -41,8 +41,8 @@ __global__ void MatrixBandPartKernel(const size_t size, const T *input_matrix_ad
template <typename T> template <typename T>
void MatrixBandPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n, const int64_t l, void MatrixBandPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n, const int64_t l,
const int64_t u, T *output_addr, cudaStream_t cuda_stream) { const int64_t u, T *output_addr, cudaStream_t cuda_stream) {
MatrixBandPartKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_matrix_addr, m, n, l, u, MatrixBandPartKernel<<<GET_BLOCKS(size), GET_THREADS_MAXSIZE(size), 0, cuda_stream>>>(size, input_matrix_addr, m, n,
output_addr, cuda_stream); l, u, output_addr, cuda_stream);
} }
template void MatrixBandPart<int32_t>(const size_t size, const int32_t *input_matrix_addr, const size_t m, template void MatrixBandPart<int32_t>(const size_t size, const int32_t *input_matrix_addr, const size_t m,

View File

@ -0,0 +1,78 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "matrix_diag_part_impl.cuh"
#include <cuda_runtime.h>
#include <algorithm>
#include "utils/complex.h"
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
__global__ void MatrixDiagPartKernel(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n,
const int64_t l, const int64_t u, const size_t num_diags,
const size_t max_diag_len, const int64_t la, const int64_t ua, T *padding_value,
T *output_addr, cudaStream_t cuda_stream) {
int64_t dest_inner_matrix_len = num_diags * max_diag_len;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
const int64_t i = pos / dest_inner_matrix_len;
const int64_t j = u - (pos % dest_inner_matrix_len) / max_diag_len;
const int64_t k = (pos % dest_inner_matrix_len) % max_diag_len;
int64_t current_diag_len = j >= 0 ? min(n - j, m) : min(m + j, n);
int64_t current_pad_len = max_diag_len - current_diag_len;
// Pad left by default (0:right, 1:left)
bool pad_left = (la == 0 && j > 0) || (ua == 0 && j < 0);
// Set none-padding values, l means current diag col index
// Source pos, k offset, only effective when pad left
int64_t k_offset = (pad_left && k >= current_pad_len) ? k - current_pad_len : k;
// Calculate source offset row/col offset
size_t row_index = j >= 0 ? j + k_offset : k_offset;
size_t col_index = j >= 0 ? k_offset : k_offset - j;
size_t source_offset = i * m * n + col_index * n + row_index;
// If current pos need pad, then the value is pad value
bool current_pad_flag = (pad_left && k < current_pad_len) || (!pad_left && k >= current_diag_len);
T current_pad_value = current_pad_flag ? *padding_value : *(input_matrix_addr + source_offset);
int64_t j_index = u - j;
size_t dest_offset = dest_inner_matrix_len * i + j_index * max_diag_len + k;
*(output_addr + dest_offset) = current_pad_value;
}
}
template <typename T>
void MatrixDiagPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n, const int64_t l,
const int64_t u, const size_t num_diags, const size_t max_diag_len, const int64_t la,
const int64_t ua, T *padding_value, T *output_addr, cudaStream_t cuda_stream) {
MatrixDiagPartKernel<<<GET_BLOCKS(size), GET_THREADS_MAXSIZE(size), 0, cuda_stream>>>(
size, input_matrix_addr, m, n, l, u, num_diags, max_diag_len, la, ua, padding_value, output_addr, cuda_stream);
}
template void MatrixDiagPart<int32_t>(const size_t size, const int32_t *input_matrix_addr, const size_t m,
const size_t n, const int64_t l, const int64_t u, const size_t num_diags,
const size_t max_diag_len, const int64_t la, const int64_t ua,
int32_t *padding_value, int32_t *output_addr, cudaStream_t cuda_stream);
template void MatrixDiagPart<int64_t>(const size_t size, const int64_t *input_matrix_addr, const size_t m,
const size_t n, const int64_t l, const int64_t u, const size_t num_diags,
const size_t max_diag_len, const int64_t la, const int64_t ua,
int64_t *padding_value, int64_t *output_addr, cudaStream_t cuda_stream);
template void MatrixDiagPart<float>(const size_t size, const float *input_matrix_addr, const size_t m, const size_t n,
const int64_t l, const int64_t u, const size_t num_diags, const size_t max_diag_len,
const int64_t la, const int64_t ua, float *padding_value, float *output_addr,
cudaStream_t cuda_stream);
template void MatrixDiagPart<double>(const size_t size, const double *input_matrix_addr, const size_t m, const size_t n,
const int64_t l, const int64_t u, const size_t num_diags,
const size_t max_diag_len, const int64_t la, const int64_t ua,
double *padding_value, double *output_addr, cudaStream_t cuda_stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_MATRIX_DIAG_PART_IMPL_CUH
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_MATRIX_DIAG_PART_IMPL_CUH
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void MatrixDiagPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n, const int64_t l,
const int64_t u, const size_t num_diags, const size_t max_diag_len, const int64_t la,
const int64_t ua, T *padding_value, T *output_addr, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_MATRIX_DIAG_PART_IMPL_CUH

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2022 Huawei Technologies Co., Ltd * Copyright 2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2022 Huawei Technologies Co., Ltd * Copyright 2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/matrix_diag_part_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(MatrixDiagPart,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MatrixDiagPartGpuKernelMod, int32_t)
MS_REG_GPU_KERNEL_ONE(MatrixDiagPart,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MatrixDiagPartGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(MatrixDiagPart,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MatrixDiagPartGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(MatrixDiagPart,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MatrixDiagPartGpuKernelMod, double)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,146 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_DIAG_PART_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_DIAG_PART_GPU_KERNEL_H
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cusolverDn.h>
#include <cuda_runtime.h>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include "utils/complex.h"
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_diag_part_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
class MatrixDiagPartGpuKernelMod : public NativeGpuKernelMod {
public:
MatrixDiagPartGpuKernelMod() : is_null_input_(false) { ResetResource(); }
~MatrixDiagPartGpuKernelMod() = default;
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (is_null_input_) {
InitSizeLists();
return true;
}
dim_size_ = shapes_.size();
if (shapes_.size() < kDim2) {
MS_LOG(EXCEPTION) << "Wrong array shape, matrix shape should not less than 2.";
}
m_ = shapes_[dim_size_ - kDim2];
n_ = shapes_[dim_size_ - kDim1];
for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
out_range_size_ *= shapes_[i];
}
matrix_size_ = out_range_size_ * m_ * n_;
InitSizeLists();
alignment_ = GetAlignments(AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAlignment));
kernel_node_ = kernel_node;
return true;
}
void PostExecute() override {
auto output_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node_.lock(), 0);
output_shape[shapes_.size() - kDim1] = max_diag_len_;
// If the out shape m' * n', the m' dimension is 1, then remove this dimension
output_shape[shapes_.size() - kDim2] = num_diags_;
if (num_diags_ == 1) {
output_shape.erase(output_shape.begin() + shapes_.size() - kDim2);
}
auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), 0);
AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get());
}
void ResetResource() noexcept override {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
auto input_matrix_addr = GetDeviceAddress<T>(inputs, kDim0);
auto d_k_range = GetDeviceAddress<int64_t>(inputs, kDim1);
auto padding_value = GetDeviceAddress<T>(inputs, kDim2);
auto output_matrix_addr = GetDeviceAddress<T>(outputs, kDim0);
int64_t k_range[kDim2]{0, 0};
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&k_range, d_k_range, kDim2 * sizeof(int64_t), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy input lower to host failed");
int64_t l = k_range[0];
int64_t u = k_range[1];
// New diagonal matrix m*n matrix, m dimension ;
if (l > u) {
MS_LOG(EXCEPTION) << "The k[1] must not less than k[0].";
}
u = std::min(u, static_cast<int64_t>(n_) - 1);
l = std::max(-(static_cast<int64_t>(m_) - 1), l);
num_diags_ = u - l + 1;
// New diagonal matrix m * n matrix, n dimension
max_diag_len_ = std::min(m_ + std::min(u, static_cast<int64_t>(0)), n_ + std::min(-l, static_cast<int64_t>(0)));
MatrixDiagPart(out_range_size_ * num_diags_ * max_diag_len_, input_matrix_addr, m_, n_, l, u, num_diags_,
max_diag_len_, alignment_.first, alignment_.second, padding_value, output_matrix_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(matrix_size_ * sizeof(T)); // Input
input_size_list_.push_back(kDim2 * sizeof(int64_t)); // k_range
input_size_list_.push_back(sizeof(T)); // padding_value
output_size_list_.push_back(matrix_size_ * sizeof(T)); // Output
}
private:
TypeId dtype_{kNumberTypeFloat32};
bool is_null_input_;
std::vector<size_t> shapes_{};
size_t dim_size_{1};
size_t matrix_size_{0};
size_t out_range_size_{1};
int64_t num_diags_{1};
int64_t max_diag_len_{1};
size_t m_{1};
size_t n_{1};
std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_DIAG_PART_GPU_KERNEL_H

View File

@ -30,6 +30,7 @@ namespace gpu {
class CudaCommon { class CudaCommon {
public: public:
inline int threads_num() const { return threads_per_block_; } inline int threads_num() const { return threads_per_block_; }
inline int threads_num(int size) const { return std::min(size, threads_per_block_); }
inline int major_sm() const { return major_sm_; } inline int major_sm() const { return major_sm_; }
inline float cuda_cap() const { return static_cast<float>(major_sm_ * 10 + minor_sm_) / 10.0; } inline float cuda_cap() const { return static_cast<float>(major_sm_ * 10 + minor_sm_) / 10.0; }
inline int blocks_num(const int total_threads) const { inline int blocks_num(const int total_threads) const {
@ -68,6 +69,7 @@ class CudaCommon {
}; };
#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) #define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads)
#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() #define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num()
#define GET_THREADS_MAXSIZE(size) mindspore::device::gpu::CudaCommon::GetInstance().threads_num(size)
#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() #define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm()
#define GET_CUDA_CAP mindspore::device::gpu::CudaCommon::GetInstance().cuda_cap() #define GET_CUDA_CAP mindspore::device::gpu::CudaCommon::GetInstance().cuda_cap()
#define SHARED_MEM_PER_BLOCK mindspore::device::gpu::CudaCommon::GetInstance().share_memory_size() #define SHARED_MEM_PER_BLOCK mindspore::device::gpu::CudaCommon::GetInstance().share_memory_size()

View File

@ -16,169 +16,205 @@
import pytest import pytest
import mindspore.scipy as msp import mindspore.scipy as msp
from mindspore import context, Tensor from mindspore import context, Tensor
from tests.st.scipy_st.utils import match_matrix from tests.st.scipy_st.utils import match_array
aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"} aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"}
PAD_VALUE = -1 PAD_VALUE = -1
Adict = {(1, 1, 1): (([[[1]]]), {}), (1, 3, 3): (([[[8, 2, 1], [5, 3, 7], [0, 3, 4]]]),
{(-2, -2, 0): ([[0]]), (-2, -1, 3): ([[[5, 3], [-1, 0]]]),
(-2, 0, 2): ([[[8, 3, 4], [5, 3, -1], [0, -1, -1]]]),
(-2, 1, 3): ([[[-1, 2, 7], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]),
(-2, 2, 0): (\
[[[1, -1, -1], [2, 7, -1], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]),
(-1, -1, 2): ([[5, 3]]), (-1, 0, 1): ([[[8, 3, 4], [5, 3, -1]]]),
(-1, 1, 2): ([[[-1, 2, 7], [8, 3, 4], [5, 3, -1]]]),
(-1, 2, 3): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4], [-1, 5, 3]]]),
(0, 0, 0): ([[8, 3, 4]]), (0, 1, 1): ([[[2, 7, -1], [8, 3, 4]]]),
(0, 2, 2): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4]]]),
(1, 1, 2): ([[2, 7]]), (1, 2, 3): ([[[-1, 1], [2, 7]]])}),
(1, 1, 2): (([[[3, 2]]]), {}), (1, 3, 5): (([[[3, 5, 5, 2, 5], [0, 6, 2, 4, 7], [7, 3, 3, 6, 8]]]),
{(-2, -2, 0): ([[7]]), (-2, -1, 3): ([[[0, 3], [-1, 7]]]),
(-2, 0, 2): ([[[3, 6, 3], [0, 3, -1], [7, -1, -1]]]),
(-2, 1, 3): ([[[5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]),
(-2, 2, 0): (\
[[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]),
(-2, 3, 1): ([
[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [0, 3, -1],
[7, -1, -1]]]), (-2, 4, 2): ([\
[[-1, -1, 5], [-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3],
[0, 3, -1], [7, -1, -1]]]), (-1, -1, 2): ([[0, 3]]),
(-1, 0, 1): ([[[3, 6, 3], [0, 3, -1]]]),
(-1, 1, 2): ([[[5, 2, 6], [3, 6, 3], [0, 3, -1]]]),
(-1, 2, 3): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]),
(-1, 3, 0): (\
[[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]),
(-1, 4, 1): ([
[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3],
[0, 3, -1]]]), (0, 0, 0): ([[3, 6, 3]]),
(0, 1, 1): ([[[5, 2, 6], [3, 6, 3]]]),
(0, 2, 2): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
(0, 3, 3): ([[[-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
(0, 4, 0): (\
[[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
(1, 1, 2): ([[5, 2, 6]]), (1, 2, 3): ([[[5, 4, 8], [5, 2, 6]]]),
(1, 3, 0): ([[[2, 7, -1], [5, 4, 8], [5, 2, 6]]]),
(1, 4, 1): ([[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6]]])}),
(1, 2, 1): (([[[4], [1]]]), {(-1, -1, 2): ([[1]]), (-1, 0, 1): ([[[4], [1]]]), (0, 0, 0): ([[4]])}),
(1, 5, 3): (([[[7, 8, 8], [3, 5, 6], [0, 4, 4], [8, 4, 5], [0, 4, 6]]]),
{(-4, -4, 0): ([[0]]), (-4, -3, 3): ([[[8, 4], [-1, 0]]]),
(-4, -2, 2): ([[[0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
(-4, -1, 1): ([[[3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
(-4, 0, 0): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4], [-1, -1, 0]]]),
(-4, 1, 1): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
(-4, 2, 2): (\
[[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
(-3, -3, 2): ([[8, 4]]), (-3, -2, 1): ([[[0, 4, 6], [8, 4, -1]]]),
(-3, -1, 0): ([[[3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
(-3, 0, 3): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
(-3, 1, 0): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
(-3, 2, 1): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1]]]),
(-2, -2, 0): ([[0, 4, 6]]), (-2, -1, 3): ([[[3, 4, 5], [0, 4, 6]]]),
(-2, 0, 2): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
(-2, 1, 3): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
(-2, 2, 0): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
(-1, -1, 2): ([[3, 4, 5]]), (-1, 0, 1): ([[[7, 5, 4], [3, 4, 5]]]),
(-1, 1, 2): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5]]]),
(-1, 2, 3): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5]]]), (0, 0, 0): ([[7, 5, 4]]),
(0, 1, 1): ([[[8, 6, -1], [7, 5, 4]]]), (0, 2, 2): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4]]]),
(1, 1, 2): ([[8, 6]]), (1, 2, 3): ([[[-1, 8], [8, 6]]]), (2, 2, 0): ([[8]])}),
(2, 1, 1): (([[[5]], [[6]]]), {}), (2, 3, 3): (\
([[[7, 6, 3], [5, 8, 5], [5, 0, 2]], [[1, 8, 1], [5, 5, 8], [8, 4, 0]]]),
{(-2, -2, 0): ([[5], [8]]), (-2, -1, 3): ([[[5, 0], [-1, 5]], [[5, 4], [-1, 8]]]),
(-2, 0, 2): ([[[7, 8, 2], [5, 0, -1], [5, -1, -1]], [[1, 5, 0], [5, 4, -1], [8, -1, -1]]]),
(-2, 1, 3): ([[[-1, 6, 5], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]], [[-1, 8, 8], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]),
(-2, 2, 0): ([[[3, -1, -1], [6, 5, -1], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]],
[[1, -1, -1], [8, 8, -1], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]), (-1, -1, 2): ([[5, 0], [5, 4]]),
(-1, 0, 1): ([[[7, 8, 2], [5, 0, -1]], [[1, 5, 0], [5, 4, -1]]]),
(-1, 1, 2): ([[[-1, 6, 5], [7, 8, 2], [5, 0, -1]], [[-1, 8, 8], [1, 5, 0], [5, 4, -1]]]),
(-1, 2, 3): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2], [-1, 5, 0]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0], [-1, 5, 4]]]),
(0, 0, 0): ([[7, 8, 2], [1, 5, 0]]), (0, 1, 1): ([[[6, 5, -1], [7, 8, 2]], [[8, 8, -1], [1, 5, 0]]]),
(0, 2, 2): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0]]]),
(1, 1, 2): ([[6, 5], [8, 8]]), (1, 2, 3): ([[[-1, 3], [6, 5]], [[-1, 1], [8, 8]]])}),
(2, 1, 2): (([[[6, 3]], [[5, 5]]]), {}), (2, 3, 5): (\
([[[1, 2, 1, 2, 7], [0, 3, 5, 0, 2], [0, 5, 1, 7, 5]], [[3, 4, 3, 5, 7], [2, 5, 2, 7, 5], [7, 5, 1, 1, 7]]]),
{(-2, -2, 0): ([[0], [7]]), (-2, -1, 3): ([[[0, 5], [-1, 0]], [[2, 5], [-1, 7]]]),
(-2, 0, 2): ([[[1, 3, 1], [0, 5, -1], [0, -1, -1]], [[3, 5, 1], [2, 5, -1], [7, -1, -1]]]),
(-2, 1, 3): ([[[2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]], [[4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]),
(-2, 2, 0): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]],
[[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]), (-2, 3, 1): (\
[[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]],
[[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]), (-2, 4, 2): (\
[[[-1, -1, 7], [-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]],
[[-1, -1, 7], [-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]),
(-1, -1, 2): ([[0, 5], [2, 5]]), (-1, 0, 1): ([[[1, 3, 1], [0, 5, -1]], [[3, 5, 1], [2, 5, -1]]]),
(-1, 1, 2): ([[[2, 5, 7], [1, 3, 1], [0, 5, -1]], [[4, 2, 1], [3, 5, 1], [2, 5, -1]]]),
(-1, 2, 3): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]], [[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]),
(-1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]],
[[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]), (-1, 4, 1): (\
[[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1]],
[[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1]]]), (0, 0, 0): ([[1, 3, 1], [3, 5, 1]]),
(0, 1, 1): ([[[2, 5, 7], [1, 3, 1]], [[4, 2, 1], [3, 5, 1]]]),
(0, 2, 2): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1]], [[3, 7, 7], [4, 2, 1], [3, 5, 1]]]),
(0, 3, 3): ([[[-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1]], [[-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]),
(0, 4, 0): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1]],
[[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]), (1, 1, 2): ([[2, 5, 7], [4, 2, 1]]),
(1, 2, 3): ([[[1, 0, 5], [2, 5, 7]], [[3, 7, 7], [4, 2, 1]]]),
(1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7]], [[5, 5, -1], [3, 7, 7], [4, 2, 1]]]),
(1, 4, 1): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7]], [[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1]]])}),
(2, 2, 1): (([[[4], [8]], [[3], [5]]]),
{(-1, -1, 2): ([[8], [5]]), (-1, 0, 1): ([[[4], [8]], [[3], [5]]]), (0, 0, 0): ([[4], [3]])}),
(2, 5, 3): (([[[6, 8, 5], [7, 2, 7], [2, 2, 5], [5, 6, 7], [5, 0, 2]],
[[3, 8, 7], [7, 8, 2], [8, 1, 0], [0, 6, 5], [6, 3, 1]]]),
{(-4, -4, 0): ([[5], [6]]), (-4, -3, 3): ([[[5, 0], [-1, 5]], [[0, 3], [-1, 6]]]),
(-4, -2, 2): ([[[2, 6, 2], [5, 0, -1], [5, -1, -1]], [[8, 6, 1], [0, 3, -1], [6, -1, -1]]]),
(-4, -1, 1): ([[[7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
[[7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 0, 0): (\
[[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0], [-1, -1, 5]],
[[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3], [-1, -1, 6]]]), (-4, 1, 1): (\
[[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
[[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 2, 2): (\
[[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
[[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]),
(-3, -3, 2): ([[5, 0], [0, 3]]),
(-3, -2, 1): ([[[2, 6, 2], [5, 0, -1]], [[8, 6, 1], [0, 3, -1]]]),
(-3, -1, 0): ([[[7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[7, 1, 5], [8, 6, 1], [-1, 0, 3]]]),
(-3, 0, 3): (\
[[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]),
(-3, 1, 0): ([[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]],
[[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]), (-3, 2, 1): (\
[[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1]],
[[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1]]]),
(-2, -2, 0): ([[2, 6, 2], [8, 6, 1]]),
(-2, -1, 3): ([[[7, 2, 7], [2, 6, 2]], [[7, 1, 5], [8, 6, 1]]]),
(-2, 0, 2): ([[[6, 2, 5], [7, 2, 7], [2, 6, 2]], [[3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
(-2, 1, 3): (\
[[[-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
(-2, 2, 0): ([[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2]],
[[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
(-1, -1, 2): ([[7, 2, 7], [7, 1, 5]]),
(-1, 0, 1): ([[[6, 2, 5], [7, 2, 7]], [[3, 8, 0], [7, 1, 5]]]),
(-1, 1, 2): ([[[-1, 8, 7], [6, 2, 5], [7, 2, 7]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5]]]),
(-1, 2, 3): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7]],
[[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5]]]),
(0, 0, 0): ([[6, 2, 5], [3, 8, 0]]),
(0, 1, 1): ([[[8, 7, -1], [6, 2, 5]], [[8, 2, -1], [3, 8, 0]]]),
(0, 2, 2): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5]], [[-1, -1, 7], [-1, 8, 2], [3, 8, 0]]]),
(1, 1, 2): ([[8, 7], [8, 2]]), (1, 2, 3): ([[[-1, 5], [8, 7]], [[-1, 7], [8, 2]]]),
(2, 2, 0): ([[5], [7]])})}
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_matrix_diag_part_net_cpu(): @pytest.mark.parametrize('array_dict', [([[[5]]], {}),
([[[3, 1, 1], [6, 4, 4], [1, 6, 4]]],
{(-2, -2, 0): [[1]], (-2, -1, 3): [[[6, 6], [-1, 1]]],
(-2, 0, 2): [[[3, 4, 4], [6, 6, -1], [1, -1, -1]]],
(-2, 1, 3): [[[-1, 1, 4], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
(-2, 2, 0): [[[1, -1, -1], [1, 4, -1], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
(-1, -1, 2): [[6, 6]], (-1, 0, 1): [[[3, 4, 4], [6, 6, -1]]],
(-1, 1, 2): [[[-1, 1, 4], [3, 4, 4], [6, 6, -1]]],
(-1, 2, 3): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4], [-1, 6, 6]]],
(0, 0, 0): [[3, 4, 4]], (0, 1, 1): [[[1, 4, -1], [3, 4, 4]]],
(0, 2, 2): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4]]], (1, 1, 2): [[1, 4]],
(1, 2, 3): [[[-1, 1], [1, 4]]]}),
([[[6, 1]]], {}),
([[[2, 2, 4, 3, 0], [8, 5, 3, 0, 3], [6, 3, 2, 6, 7]]],
{(-2, -2, 0): [[6]], (-2, -1, 3): [[[8, 3], [-1, 6]]],
(-2, 0, 2): [[[2, 5, 2], [8, 3, -1], [6, -1, -1]]],
(-2, 1, 3): [[[2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
(-2, 2, 0): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
(-2, 3, 1): [
[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1], [6, -1, -1]]],
(-2, 4, 2): [
[[-1, -1, 0], [-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1],
[6, -1, -1]]], (-1, -1, 2): [[8, 3]],
(-1, 0, 1): [[[2, 5, 2], [8, 3, -1]]],
(-1, 1, 2): [[[2, 3, 6], [2, 5, 2], [8, 3, -1]]],
(-1, 2, 3): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
(-1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
(-1, 4, 1): [
[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1]]],
(0, 0, 0): [[2, 5, 2]], (0, 1, 1): [[[2, 3, 6], [2, 5, 2]]],
(0, 2, 2): [[[4, 0, 7], [2, 3, 6], [2, 5, 2]]],
(0, 3, 3): [[[-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
(0, 4, 0): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
(1, 1, 2): [[2, 3, 6]], (1, 2, 3): [[[4, 0, 7], [2, 3, 6]]],
(1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6]]],
(1, 4, 1): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6]]]}),
([[[5], [5]]], {(-1, -1, 2): [[5]], (-1, 0, 1): [[[5], [5]]],
(0, 0, 0): [[5]]}),
([[[2, 4, 1], [6, 4, 1], [0, 5, 2], [1, 6, 0], [1, 0, 7]]],
{(-4, -4, 0): [[1]], (-4, -3, 3): [[[1, 0], [-1, 1]]],
(-4, -2, 2): [[[0, 6, 7], [1, 0, -1], [1, -1, -1]]],
(-4, -1, 1): [[[6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
(-4, 0, 0): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0], [-1, -1, 1]]],
(-4, 1, 1): [
[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
(-4, 2, 2): [
[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1],
[1, -1, -1]]], (-3, -3, 2): [[1, 0]],
(-3, -2, 1): [[[0, 6, 7], [1, 0, -1]]],
(-3, -1, 0): [[[6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
(-3, 0, 3): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
(-3, 1, 0): [[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
(-3, 2, 1): [
[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1]]],
(-2, -2, 0): [[0, 6, 7]], (-2, -1, 3): [[[6, 5, 0], [0, 6, 7]]],
(-2, 0, 2): [[[2, 4, 2], [6, 5, 0], [0, 6, 7]]],
(-2, 1, 3): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
(-2, 2, 0): [[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
(-1, -1, 2): [[6, 5, 0]], (-1, 0, 1): [[[2, 4, 2], [6, 5, 0]]],
(-1, 1, 2): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
(-1, 2, 3): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
(0, 0, 0): [[2, 4, 2]], (0, 1, 1): [[[4, 1, -1], [2, 4, 2]]],
(0, 2, 2): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2]]], (1, 1, 2): [[4, 1]],
(1, 2, 3): [[[-1, 1], [4, 1]]], (2, 2, 0): [[1]]}),
([[[6]], [[4]]], {}),
([[[2, 4, 8], [3, 4, 2], [1, 6, 3]], [[6, 7, 2], [8, 2, 1], [4, 5, 5]]],
{(-2, -2, 0): [[1], [4]], (-2, -1, 3): [[[3, 6], [-1, 1]], [[8, 5], [-1, 4]]],
(-2, 0, 2): [[[2, 4, 3], [3, 6, -1], [1, -1, -1]],
[[6, 2, 5], [8, 5, -1], [4, -1, -1]]],
(-2, 1, 3): [[[-1, 4, 2], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
[[-1, 7, 1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
(-2, 2, 0): [[[8, -1, -1], [4, 2, -1], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
[[2, -1, -1], [7, 1, -1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
(-1, -1, 2): [[3, 6], [8, 5]],
(-1, 0, 1): [[[2, 4, 3], [3, 6, -1]], [[6, 2, 5], [8, 5, -1]]],
(-1, 1, 2): [[[-1, 4, 2], [2, 4, 3], [3, 6, -1]],
[[-1, 7, 1], [6, 2, 5], [8, 5, -1]]],
(-1, 2, 3): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3], [-1, 3, 6]],
[[-1, -1, 2], [-1, 7, 1], [6, 2, 5], [-1, 8, 5]]],
(0, 0, 0): [[2, 4, 3], [6, 2, 5]],
(0, 1, 1): [[[4, 2, -1], [2, 4, 3]], [[7, 1, -1], [6, 2, 5]]],
(0, 2, 2): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3]],
[[-1, -1, 2], [-1, 7, 1], [6, 2, 5]]],
(1, 1, 2): [[4, 2], [7, 1]],
(1, 2, 3): [[[-1, 8], [4, 2]], [[-1, 2], [7, 1]]]}),
([[[4, 0]], [[7, 4]]], {}),
([[[3, 5, 8, 3, 5], [7, 8, 1, 0, 6], [5, 4, 0, 3, 6]],
[[7, 4, 8, 7, 3], [4, 6, 5, 7, 1], [5, 3, 1, 1, 0]]],
{(-2, -2, 0): [[5], [5]], (-2, -1, 3): [[[7, 4], [-1, 5]], [[4, 3], [-1, 5]]],
(-2, 0, 2): [[[3, 8, 0], [7, 4, -1], [5, -1, -1]],
[[7, 6, 1], [4, 3, -1], [5, -1, -1]]],
(-2, 1, 3): [[[5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
[[4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
(-2, 2, 0): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
[[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
(-2, 3, 1): [
[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1], [5, -1, -1]],
[[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1], [5, -1, -1]]],
(-2, 4, 2): [
[[-1, -1, 5], [-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1],
[5, -1, -1]],
[[-1, -1, 3], [-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1],
[5, -1, -1]]],
(-1, -1, 2): [[7, 4], [4, 3]],
(-1, 0, 1): [[[3, 8, 0], [7, 4, -1]], [[7, 6, 1], [4, 3, -1]]],
(-1, 1, 2): [[[5, 1, 3], [3, 8, 0], [7, 4, -1]],
[[4, 5, 1], [7, 6, 1], [4, 3, -1]]],
(-1, 2, 3): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
[[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
(-1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
[[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
(-1, 4, 1): [
[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1]],
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1]]],
(0, 0, 0): [[3, 8, 0], [7, 6, 1]],
(0, 1, 1): [[[5, 1, 3], [3, 8, 0]], [[4, 5, 1], [7, 6, 1]]],
(0, 2, 2): [[[8, 0, 6], [5, 1, 3], [3, 8, 0]],
[[8, 7, 0], [4, 5, 1], [7, 6, 1]]],
(0, 3, 3): [[[-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
[[-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
(0, 4, 0): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
(1, 1, 2): [[5, 1, 3], [4, 5, 1]],
(1, 2, 3): [[[8, 0, 6], [5, 1, 3]], [[8, 7, 0], [4, 5, 1]]],
(1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3]],
[[7, 1, -1], [8, 7, 0], [4, 5, 1]]],
(1, 4, 1): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3]],
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1]]]}),
([[[4], [7]], [[3], [5]]],
{(-1, -1, 2): [[7], [5]], (-1, 0, 1): [[[4], [7]], [[3], [5]]],
(0, 0, 0): [[4], [3]]}),
([[[0, 2, 2], [0, 0, 5], [6, 5, 5], [5, 8, 5], [3, 8, 0]],
[[2, 8, 3], [4, 4, 1], [0, 4, 2], [0, 7, 0], [0, 7, 4]]],
{(-4, -4, 0): [[3], [0]], (-4, -3, 3): [[[5, 8], [-1, 3]], [[0, 7], [-1, 0]]],
(-4, -2, 2): [[[6, 8, 0], [5, 8, -1], [3, -1, -1]],
[[0, 7, 4], [0, 7, -1], [0, -1, -1]]],
(-4, -1, 1): [[[0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
[[4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
(-4, 0, 0): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8], [-1, -1, 3]],
[[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7], [-1, -1, 0]]],
(-4, 1, 1): [
[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
[[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
(-4, 2, 2): [
[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1],
[3, -1, -1]],
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1],
[0, -1, -1]]], (-3, -3, 2): [[5, 8], [0, 7]],
(-3, -2, 1): [[[6, 8, 0], [5, 8, -1]], [[0, 7, 4], [0, 7, -1]]],
(-3, -1, 0): [[[0, 5, 5], [6, 8, 0], [-1, 5, 8]],
[[4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
(-3, 0, 3): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
[[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
(-3, 1, 0): [[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
[[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
(-3, 2, 1): [
[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1]],
[[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1]]],
(-2, -2, 0): [[6, 8, 0], [0, 7, 4]],
(-2, -1, 3): [[[0, 5, 5], [6, 8, 0]], [[4, 4, 0], [0, 7, 4]]],
(-2, 0, 2): [[[0, 0, 5], [0, 5, 5], [6, 8, 0]],
[[2, 4, 2], [4, 4, 0], [0, 7, 4]]],
(-2, 1, 3): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
[[-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
(-2, 2, 0): [[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
[[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
(-1, -1, 2): [[0, 5, 5], [4, 4, 0]],
(-1, 0, 1): [[[0, 0, 5], [0, 5, 5]], [[2, 4, 2], [4, 4, 0]]],
(-1, 1, 2): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5]],
[[-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
(-1, 2, 3): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5]],
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
(0, 0, 0): [[0, 0, 5], [2, 4, 2]],
(0, 1, 1): [[[2, 5, -1], [0, 0, 5]], [[8, 1, -1], [2, 4, 2]]],
(0, 2, 2): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5]],
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2]]],
(1, 1, 2): [[2, 5], [8, 1]],
(1, 2, 3): [[[-1, 2], [2, 5]], [[-1, 3], [8, 1]]], (2, 2, 0): [[2], [3]]})])
def test_matrix_diag_part_net(array_dict):
""" """
testcase generate from below testcase generate from below
from tensorflow.python.ops import array_ops from tf.python.ops import array_ops
import numpy as np import numpy as np
f = open (r'dict.tst','w') f = open (r'dict.tst','w')
aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"} aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"}
Adict={} Adict=[]
for i in [1, 2]: for i in [1, 2]:
for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]: for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]:
A = np.array(np.random.randint(20, size=(i, m, n))) A = np.array(np.random.randint(20, size=(i, m, n)))
Adict[i,m,n]=A
p = -1
kadict={} kadict={}
for k0 in range(-m + 1, m - 1): for k0 in range(-m + 1, m - 1):
for k1 in range(k0, n): for k1 in range(k0, n):
@ -187,7 +223,7 @@ def test_matrix_diag_part_net_cpu():
ka = (k,align_) ka = (k,align_)
B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1) B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1)
kadict[ka] = B.numpy() kadict[ka] = B.numpy()
Adict[i,m,n]=(A, kadict) Adict.append(A, kadict)
print(Adict, file= f) print(Adict, file= f)
f.close() f.close()
Feature: ALL To ALL Feature: ALL To ALL
@ -195,12 +231,11 @@ def test_matrix_diag_part_net_cpu():
Expectation: the result match to numpy Expectation: the result match to numpy
""" """
context.set_context(mode=context.PYNATIVE_MODE) context.set_context(mode=context.PYNATIVE_MODE)
for _, value in Adict.items(): a, kadict = array_dict
a, kadict = value for key1, b in kadict.items():
for key1, b in kadict.items(): k0, k1, align_ = key1
k0, k1, align_ = key1 if k0 == k1:
if k0 == k1: r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), k0, PAD_VALUE, align=aligndict[align_])
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), k0, PAD_VALUE, align=aligndict[align_]) else:
else: r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), (k0, k1), PAD_VALUE, align=aligndict[align_])
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), (k0, k1), PAD_VALUE, align=aligndict[align_]) match_array(b, r_b.asnumpy())
match_matrix(Tensor(b), Tensor(r_b))