forked from mindspore-Ecosystem/mindspore
!29625 Add GPU matrix_diag_part ops, optimize related testcase
Merge pull request !29625 from wuwenbing/dev
This commit is contained in:
commit
776d937266
|
@ -41,8 +41,8 @@ __global__ void MatrixBandPartKernel(const size_t size, const T *input_matrix_ad
|
|||
template <typename T>
|
||||
void MatrixBandPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n, const int64_t l,
|
||||
const int64_t u, T *output_addr, cudaStream_t cuda_stream) {
|
||||
MatrixBandPartKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_matrix_addr, m, n, l, u,
|
||||
output_addr, cuda_stream);
|
||||
MatrixBandPartKernel<<<GET_BLOCKS(size), GET_THREADS_MAXSIZE(size), 0, cuda_stream>>>(size, input_matrix_addr, m, n,
|
||||
l, u, output_addr, cuda_stream);
|
||||
}
|
||||
|
||||
template void MatrixBandPart<int32_t>(const size_t size, const int32_t *input_matrix_addr, const size_t m,
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "matrix_diag_part_impl.cuh"
|
||||
#include <cuda_runtime.h>
|
||||
#include <algorithm>
|
||||
#include "utils/complex.h"
|
||||
|
||||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
|
||||
template <typename T>
|
||||
__global__ void MatrixDiagPartKernel(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n,
|
||||
const int64_t l, const int64_t u, const size_t num_diags,
|
||||
const size_t max_diag_len, const int64_t la, const int64_t ua, T *padding_value,
|
||||
T *output_addr, cudaStream_t cuda_stream) {
|
||||
int64_t dest_inner_matrix_len = num_diags * max_diag_len;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
const int64_t i = pos / dest_inner_matrix_len;
|
||||
const int64_t j = u - (pos % dest_inner_matrix_len) / max_diag_len;
|
||||
const int64_t k = (pos % dest_inner_matrix_len) % max_diag_len;
|
||||
int64_t current_diag_len = j >= 0 ? min(n - j, m) : min(m + j, n);
|
||||
int64_t current_pad_len = max_diag_len - current_diag_len;
|
||||
// Pad left by default (0:right, 1:left)
|
||||
bool pad_left = (la == 0 && j > 0) || (ua == 0 && j < 0);
|
||||
// Set none-padding values, l means current diag col index
|
||||
// Source pos, k offset, only effective when pad left
|
||||
int64_t k_offset = (pad_left && k >= current_pad_len) ? k - current_pad_len : k;
|
||||
|
||||
// Calculate source offset row/col offset
|
||||
size_t row_index = j >= 0 ? j + k_offset : k_offset;
|
||||
size_t col_index = j >= 0 ? k_offset : k_offset - j;
|
||||
size_t source_offset = i * m * n + col_index * n + row_index;
|
||||
// If current pos need pad, then the value is pad value
|
||||
bool current_pad_flag = (pad_left && k < current_pad_len) || (!pad_left && k >= current_diag_len);
|
||||
T current_pad_value = current_pad_flag ? *padding_value : *(input_matrix_addr + source_offset);
|
||||
int64_t j_index = u - j;
|
||||
size_t dest_offset = dest_inner_matrix_len * i + j_index * max_diag_len + k;
|
||||
*(output_addr + dest_offset) = current_pad_value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MatrixDiagPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n, const int64_t l,
|
||||
const int64_t u, const size_t num_diags, const size_t max_diag_len, const int64_t la,
|
||||
const int64_t ua, T *padding_value, T *output_addr, cudaStream_t cuda_stream) {
|
||||
MatrixDiagPartKernel<<<GET_BLOCKS(size), GET_THREADS_MAXSIZE(size), 0, cuda_stream>>>(
|
||||
size, input_matrix_addr, m, n, l, u, num_diags, max_diag_len, la, ua, padding_value, output_addr, cuda_stream);
|
||||
}
|
||||
|
||||
template void MatrixDiagPart<int32_t>(const size_t size, const int32_t *input_matrix_addr, const size_t m,
|
||||
const size_t n, const int64_t l, const int64_t u, const size_t num_diags,
|
||||
const size_t max_diag_len, const int64_t la, const int64_t ua,
|
||||
int32_t *padding_value, int32_t *output_addr, cudaStream_t cuda_stream);
|
||||
template void MatrixDiagPart<int64_t>(const size_t size, const int64_t *input_matrix_addr, const size_t m,
|
||||
const size_t n, const int64_t l, const int64_t u, const size_t num_diags,
|
||||
const size_t max_diag_len, const int64_t la, const int64_t ua,
|
||||
int64_t *padding_value, int64_t *output_addr, cudaStream_t cuda_stream);
|
||||
template void MatrixDiagPart<float>(const size_t size, const float *input_matrix_addr, const size_t m, const size_t n,
|
||||
const int64_t l, const int64_t u, const size_t num_diags, const size_t max_diag_len,
|
||||
const int64_t la, const int64_t ua, float *padding_value, float *output_addr,
|
||||
cudaStream_t cuda_stream);
|
||||
template void MatrixDiagPart<double>(const size_t size, const double *input_matrix_addr, const size_t m, const size_t n,
|
||||
const int64_t l, const int64_t u, const size_t num_diags,
|
||||
const size_t max_diag_len, const int64_t la, const int64_t ua,
|
||||
double *padding_value, double *output_addr, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_MATRIX_DIAG_PART_IMPL_CUH
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_MATRIX_DIAG_PART_IMPL_CUH
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void MatrixDiagPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n, const int64_t l,
|
||||
const int64_t u, const size_t num_diags, const size_t max_diag_len, const int64_t la,
|
||||
const int64_t ua, T *padding_value, T *output_addr, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_MATRIX_DIAG_PART_IMPL_CUH
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/matrix_diag_part_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(MatrixDiagPart,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
MatrixDiagPartGpuKernelMod, int32_t)
|
||||
MS_REG_GPU_KERNEL_ONE(MatrixDiagPart,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
MatrixDiagPartGpuKernelMod, int64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(MatrixDiagPart,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MatrixDiagPartGpuKernelMod, float)
|
||||
MS_REG_GPU_KERNEL_ONE(MatrixDiagPart,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
MatrixDiagPartGpuKernelMod, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,146 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_DIAG_PART_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_DIAG_PART_GPU_KERNEL_H
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusolverDn.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include "utils/complex.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/matrix_diag_part_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
|
||||
template <typename T>
|
||||
class MatrixDiagPartGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
MatrixDiagPartGpuKernelMod() : is_null_input_(false) { ResetResource(); }
|
||||
|
||||
~MatrixDiagPartGpuKernelMod() = default;
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
dim_size_ = shapes_.size();
|
||||
if (shapes_.size() < kDim2) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, matrix shape should not less than 2.";
|
||||
}
|
||||
m_ = shapes_[dim_size_ - kDim2];
|
||||
n_ = shapes_[dim_size_ - kDim1];
|
||||
for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
|
||||
out_range_size_ *= shapes_[i];
|
||||
}
|
||||
matrix_size_ = out_range_size_ * m_ * n_;
|
||||
InitSizeLists();
|
||||
alignment_ = GetAlignments(AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAlignment));
|
||||
kernel_node_ = kernel_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
void PostExecute() override {
|
||||
auto output_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node_.lock(), 0);
|
||||
output_shape[shapes_.size() - kDim1] = max_diag_len_;
|
||||
// If the out shape m' * n', the m' dimension is 1, then remove this dimension
|
||||
output_shape[shapes_.size() - kDim2] = num_diags_;
|
||||
if (num_diags_ == 1) {
|
||||
output_shape.erase(output_shape.begin() + shapes_.size() - kDim2);
|
||||
}
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), 0);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get());
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
auto input_matrix_addr = GetDeviceAddress<T>(inputs, kDim0);
|
||||
auto d_k_range = GetDeviceAddress<int64_t>(inputs, kDim1);
|
||||
auto padding_value = GetDeviceAddress<T>(inputs, kDim2);
|
||||
auto output_matrix_addr = GetDeviceAddress<T>(outputs, kDim0);
|
||||
|
||||
int64_t k_range[kDim2]{0, 0};
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(&k_range, d_k_range, kDim2 * sizeof(int64_t), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy input lower to host failed");
|
||||
int64_t l = k_range[0];
|
||||
int64_t u = k_range[1];
|
||||
// New diagonal matrix m*n matrix, m dimension ;
|
||||
if (l > u) {
|
||||
MS_LOG(EXCEPTION) << "The k[1] must not less than k[0].";
|
||||
}
|
||||
u = std::min(u, static_cast<int64_t>(n_) - 1);
|
||||
l = std::max(-(static_cast<int64_t>(m_) - 1), l);
|
||||
num_diags_ = u - l + 1;
|
||||
// New diagonal matrix m * n matrix, n dimension
|
||||
max_diag_len_ = std::min(m_ + std::min(u, static_cast<int64_t>(0)), n_ + std::min(-l, static_cast<int64_t>(0)));
|
||||
MatrixDiagPart(out_range_size_ * num_diags_ * max_diag_len_, input_matrix_addr, m_, n_, l, u, num_diags_,
|
||||
max_diag_len_, alignment_.first, alignment_.second, padding_value, output_matrix_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(matrix_size_ * sizeof(T)); // Input
|
||||
input_size_list_.push_back(kDim2 * sizeof(int64_t)); // k_range
|
||||
input_size_list_.push_back(sizeof(T)); // padding_value
|
||||
output_size_list_.push_back(matrix_size_ * sizeof(T)); // Output
|
||||
}
|
||||
|
||||
private:
|
||||
TypeId dtype_{kNumberTypeFloat32};
|
||||
bool is_null_input_;
|
||||
std::vector<size_t> shapes_{};
|
||||
size_t dim_size_{1};
|
||||
size_t matrix_size_{0};
|
||||
size_t out_range_size_{1};
|
||||
int64_t num_diags_{1};
|
||||
int64_t max_diag_len_{1};
|
||||
size_t m_{1};
|
||||
size_t n_{1};
|
||||
std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_DIAG_PART_GPU_KERNEL_H
|
|
@ -30,6 +30,7 @@ namespace gpu {
|
|||
class CudaCommon {
|
||||
public:
|
||||
inline int threads_num() const { return threads_per_block_; }
|
||||
inline int threads_num(int size) const { return std::min(size, threads_per_block_); }
|
||||
inline int major_sm() const { return major_sm_; }
|
||||
inline float cuda_cap() const { return static_cast<float>(major_sm_ * 10 + minor_sm_) / 10.0; }
|
||||
inline int blocks_num(const int total_threads) const {
|
||||
|
@ -68,6 +69,7 @@ class CudaCommon {
|
|||
};
|
||||
#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads)
|
||||
#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num()
|
||||
#define GET_THREADS_MAXSIZE(size) mindspore::device::gpu::CudaCommon::GetInstance().threads_num(size)
|
||||
#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm()
|
||||
#define GET_CUDA_CAP mindspore::device::gpu::CudaCommon::GetInstance().cuda_cap()
|
||||
#define SHARED_MEM_PER_BLOCK mindspore::device::gpu::CudaCommon::GetInstance().share_memory_size()
|
||||
|
|
|
@ -16,169 +16,205 @@
|
|||
import pytest
|
||||
import mindspore.scipy as msp
|
||||
from mindspore import context, Tensor
|
||||
from tests.st.scipy_st.utils import match_matrix
|
||||
from tests.st.scipy_st.utils import match_array
|
||||
|
||||
aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"}
|
||||
PAD_VALUE = -1
|
||||
Adict = {(1, 1, 1): (([[[1]]]), {}), (1, 3, 3): (([[[8, 2, 1], [5, 3, 7], [0, 3, 4]]]),
|
||||
{(-2, -2, 0): ([[0]]), (-2, -1, 3): ([[[5, 3], [-1, 0]]]),
|
||||
(-2, 0, 2): ([[[8, 3, 4], [5, 3, -1], [0, -1, -1]]]),
|
||||
(-2, 1, 3): ([[[-1, 2, 7], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]),
|
||||
(-2, 2, 0): (\
|
||||
[[[1, -1, -1], [2, 7, -1], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]),
|
||||
(-1, -1, 2): ([[5, 3]]), (-1, 0, 1): ([[[8, 3, 4], [5, 3, -1]]]),
|
||||
(-1, 1, 2): ([[[-1, 2, 7], [8, 3, 4], [5, 3, -1]]]),
|
||||
(-1, 2, 3): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4], [-1, 5, 3]]]),
|
||||
(0, 0, 0): ([[8, 3, 4]]), (0, 1, 1): ([[[2, 7, -1], [8, 3, 4]]]),
|
||||
(0, 2, 2): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4]]]),
|
||||
(1, 1, 2): ([[2, 7]]), (1, 2, 3): ([[[-1, 1], [2, 7]]])}),
|
||||
(1, 1, 2): (([[[3, 2]]]), {}), (1, 3, 5): (([[[3, 5, 5, 2, 5], [0, 6, 2, 4, 7], [7, 3, 3, 6, 8]]]),
|
||||
{(-2, -2, 0): ([[7]]), (-2, -1, 3): ([[[0, 3], [-1, 7]]]),
|
||||
(-2, 0, 2): ([[[3, 6, 3], [0, 3, -1], [7, -1, -1]]]),
|
||||
(-2, 1, 3): ([[[5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]),
|
||||
(-2, 2, 0): (\
|
||||
[[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]),
|
||||
(-2, 3, 1): ([
|
||||
[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [0, 3, -1],
|
||||
[7, -1, -1]]]), (-2, 4, 2): ([\
|
||||
[[-1, -1, 5], [-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3],
|
||||
[0, 3, -1], [7, -1, -1]]]), (-1, -1, 2): ([[0, 3]]),
|
||||
(-1, 0, 1): ([[[3, 6, 3], [0, 3, -1]]]),
|
||||
(-1, 1, 2): ([[[5, 2, 6], [3, 6, 3], [0, 3, -1]]]),
|
||||
(-1, 2, 3): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]),
|
||||
(-1, 3, 0): (\
|
||||
[[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]),
|
||||
(-1, 4, 1): ([
|
||||
[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3],
|
||||
[0, 3, -1]]]), (0, 0, 0): ([[3, 6, 3]]),
|
||||
(0, 1, 1): ([[[5, 2, 6], [3, 6, 3]]]),
|
||||
(0, 2, 2): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
|
||||
(0, 3, 3): ([[[-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
|
||||
(0, 4, 0): (\
|
||||
[[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
|
||||
(1, 1, 2): ([[5, 2, 6]]), (1, 2, 3): ([[[5, 4, 8], [5, 2, 6]]]),
|
||||
(1, 3, 0): ([[[2, 7, -1], [5, 4, 8], [5, 2, 6]]]),
|
||||
(1, 4, 1): ([[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6]]])}),
|
||||
(1, 2, 1): (([[[4], [1]]]), {(-1, -1, 2): ([[1]]), (-1, 0, 1): ([[[4], [1]]]), (0, 0, 0): ([[4]])}),
|
||||
(1, 5, 3): (([[[7, 8, 8], [3, 5, 6], [0, 4, 4], [8, 4, 5], [0, 4, 6]]]),
|
||||
{(-4, -4, 0): ([[0]]), (-4, -3, 3): ([[[8, 4], [-1, 0]]]),
|
||||
(-4, -2, 2): ([[[0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
|
||||
(-4, -1, 1): ([[[3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
|
||||
(-4, 0, 0): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4], [-1, -1, 0]]]),
|
||||
(-4, 1, 1): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
|
||||
(-4, 2, 2): (\
|
||||
[[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
|
||||
(-3, -3, 2): ([[8, 4]]), (-3, -2, 1): ([[[0, 4, 6], [8, 4, -1]]]),
|
||||
(-3, -1, 0): ([[[3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
|
||||
(-3, 0, 3): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
|
||||
(-3, 1, 0): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
|
||||
(-3, 2, 1): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1]]]),
|
||||
(-2, -2, 0): ([[0, 4, 6]]), (-2, -1, 3): ([[[3, 4, 5], [0, 4, 6]]]),
|
||||
(-2, 0, 2): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
|
||||
(-2, 1, 3): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
|
||||
(-2, 2, 0): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
|
||||
(-1, -1, 2): ([[3, 4, 5]]), (-1, 0, 1): ([[[7, 5, 4], [3, 4, 5]]]),
|
||||
(-1, 1, 2): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5]]]),
|
||||
(-1, 2, 3): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5]]]), (0, 0, 0): ([[7, 5, 4]]),
|
||||
(0, 1, 1): ([[[8, 6, -1], [7, 5, 4]]]), (0, 2, 2): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4]]]),
|
||||
(1, 1, 2): ([[8, 6]]), (1, 2, 3): ([[[-1, 8], [8, 6]]]), (2, 2, 0): ([[8]])}),
|
||||
(2, 1, 1): (([[[5]], [[6]]]), {}), (2, 3, 3): (\
|
||||
([[[7, 6, 3], [5, 8, 5], [5, 0, 2]], [[1, 8, 1], [5, 5, 8], [8, 4, 0]]]),
|
||||
{(-2, -2, 0): ([[5], [8]]), (-2, -1, 3): ([[[5, 0], [-1, 5]], [[5, 4], [-1, 8]]]),
|
||||
(-2, 0, 2): ([[[7, 8, 2], [5, 0, -1], [5, -1, -1]], [[1, 5, 0], [5, 4, -1], [8, -1, -1]]]),
|
||||
(-2, 1, 3): ([[[-1, 6, 5], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]], [[-1, 8, 8], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]),
|
||||
(-2, 2, 0): ([[[3, -1, -1], [6, 5, -1], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]],
|
||||
[[1, -1, -1], [8, 8, -1], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]), (-1, -1, 2): ([[5, 0], [5, 4]]),
|
||||
(-1, 0, 1): ([[[7, 8, 2], [5, 0, -1]], [[1, 5, 0], [5, 4, -1]]]),
|
||||
(-1, 1, 2): ([[[-1, 6, 5], [7, 8, 2], [5, 0, -1]], [[-1, 8, 8], [1, 5, 0], [5, 4, -1]]]),
|
||||
(-1, 2, 3): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2], [-1, 5, 0]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0], [-1, 5, 4]]]),
|
||||
(0, 0, 0): ([[7, 8, 2], [1, 5, 0]]), (0, 1, 1): ([[[6, 5, -1], [7, 8, 2]], [[8, 8, -1], [1, 5, 0]]]),
|
||||
(0, 2, 2): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0]]]),
|
||||
(1, 1, 2): ([[6, 5], [8, 8]]), (1, 2, 3): ([[[-1, 3], [6, 5]], [[-1, 1], [8, 8]]])}),
|
||||
(2, 1, 2): (([[[6, 3]], [[5, 5]]]), {}), (2, 3, 5): (\
|
||||
([[[1, 2, 1, 2, 7], [0, 3, 5, 0, 2], [0, 5, 1, 7, 5]], [[3, 4, 3, 5, 7], [2, 5, 2, 7, 5], [7, 5, 1, 1, 7]]]),
|
||||
{(-2, -2, 0): ([[0], [7]]), (-2, -1, 3): ([[[0, 5], [-1, 0]], [[2, 5], [-1, 7]]]),
|
||||
(-2, 0, 2): ([[[1, 3, 1], [0, 5, -1], [0, -1, -1]], [[3, 5, 1], [2, 5, -1], [7, -1, -1]]]),
|
||||
(-2, 1, 3): ([[[2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]], [[4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]),
|
||||
(-2, 2, 0): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]],
|
||||
[[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]), (-2, 3, 1): (\
|
||||
[[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]],
|
||||
[[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]), (-2, 4, 2): (\
|
||||
[[[-1, -1, 7], [-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]],
|
||||
[[-1, -1, 7], [-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]),
|
||||
(-1, -1, 2): ([[0, 5], [2, 5]]), (-1, 0, 1): ([[[1, 3, 1], [0, 5, -1]], [[3, 5, 1], [2, 5, -1]]]),
|
||||
(-1, 1, 2): ([[[2, 5, 7], [1, 3, 1], [0, 5, -1]], [[4, 2, 1], [3, 5, 1], [2, 5, -1]]]),
|
||||
(-1, 2, 3): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]], [[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]),
|
||||
(-1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]],
|
||||
[[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]), (-1, 4, 1): (\
|
||||
[[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1]],
|
||||
[[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1]]]), (0, 0, 0): ([[1, 3, 1], [3, 5, 1]]),
|
||||
(0, 1, 1): ([[[2, 5, 7], [1, 3, 1]], [[4, 2, 1], [3, 5, 1]]]),
|
||||
(0, 2, 2): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1]], [[3, 7, 7], [4, 2, 1], [3, 5, 1]]]),
|
||||
(0, 3, 3): ([[[-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1]], [[-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]),
|
||||
(0, 4, 0): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1]],
|
||||
[[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]), (1, 1, 2): ([[2, 5, 7], [4, 2, 1]]),
|
||||
(1, 2, 3): ([[[1, 0, 5], [2, 5, 7]], [[3, 7, 7], [4, 2, 1]]]),
|
||||
(1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7]], [[5, 5, -1], [3, 7, 7], [4, 2, 1]]]),
|
||||
(1, 4, 1): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7]], [[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1]]])}),
|
||||
(2, 2, 1): (([[[4], [8]], [[3], [5]]]),
|
||||
{(-1, -1, 2): ([[8], [5]]), (-1, 0, 1): ([[[4], [8]], [[3], [5]]]), (0, 0, 0): ([[4], [3]])}),
|
||||
(2, 5, 3): (([[[6, 8, 5], [7, 2, 7], [2, 2, 5], [5, 6, 7], [5, 0, 2]],
|
||||
[[3, 8, 7], [7, 8, 2], [8, 1, 0], [0, 6, 5], [6, 3, 1]]]),
|
||||
{(-4, -4, 0): ([[5], [6]]), (-4, -3, 3): ([[[5, 0], [-1, 5]], [[0, 3], [-1, 6]]]),
|
||||
(-4, -2, 2): ([[[2, 6, 2], [5, 0, -1], [5, -1, -1]], [[8, 6, 1], [0, 3, -1], [6, -1, -1]]]),
|
||||
(-4, -1, 1): ([[[7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
|
||||
[[7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 0, 0): (\
|
||||
[[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0], [-1, -1, 5]],
|
||||
[[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3], [-1, -1, 6]]]), (-4, 1, 1): (\
|
||||
[[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
|
||||
[[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 2, 2): (\
|
||||
[[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
|
||||
[[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]),
|
||||
(-3, -3, 2): ([[5, 0], [0, 3]]),
|
||||
(-3, -2, 1): ([[[2, 6, 2], [5, 0, -1]], [[8, 6, 1], [0, 3, -1]]]),
|
||||
(-3, -1, 0): ([[[7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[7, 1, 5], [8, 6, 1], [-1, 0, 3]]]),
|
||||
(-3, 0, 3): (\
|
||||
[[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]),
|
||||
(-3, 1, 0): ([[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]],
|
||||
[[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]), (-3, 2, 1): (\
|
||||
[[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1]],
|
||||
[[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1]]]),
|
||||
(-2, -2, 0): ([[2, 6, 2], [8, 6, 1]]),
|
||||
(-2, -1, 3): ([[[7, 2, 7], [2, 6, 2]], [[7, 1, 5], [8, 6, 1]]]),
|
||||
(-2, 0, 2): ([[[6, 2, 5], [7, 2, 7], [2, 6, 2]], [[3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
|
||||
(-2, 1, 3): (\
|
||||
[[[-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
|
||||
(-2, 2, 0): ([[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2]],
|
||||
[[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
|
||||
(-1, -1, 2): ([[7, 2, 7], [7, 1, 5]]),
|
||||
(-1, 0, 1): ([[[6, 2, 5], [7, 2, 7]], [[3, 8, 0], [7, 1, 5]]]),
|
||||
(-1, 1, 2): ([[[-1, 8, 7], [6, 2, 5], [7, 2, 7]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5]]]),
|
||||
(-1, 2, 3): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7]],
|
||||
[[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5]]]),
|
||||
(0, 0, 0): ([[6, 2, 5], [3, 8, 0]]),
|
||||
(0, 1, 1): ([[[8, 7, -1], [6, 2, 5]], [[8, 2, -1], [3, 8, 0]]]),
|
||||
(0, 2, 2): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5]], [[-1, -1, 7], [-1, 8, 2], [3, 8, 0]]]),
|
||||
(1, 1, 2): ([[8, 7], [8, 2]]), (1, 2, 3): ([[[-1, 5], [8, 7]], [[-1, 7], [8, 2]]]),
|
||||
(2, 2, 0): ([[5], [7]])})}
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_matrix_diag_part_net_cpu():
|
||||
@pytest.mark.parametrize('array_dict', [([[[5]]], {}),
|
||||
([[[3, 1, 1], [6, 4, 4], [1, 6, 4]]],
|
||||
{(-2, -2, 0): [[1]], (-2, -1, 3): [[[6, 6], [-1, 1]]],
|
||||
(-2, 0, 2): [[[3, 4, 4], [6, 6, -1], [1, -1, -1]]],
|
||||
(-2, 1, 3): [[[-1, 1, 4], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
|
||||
(-2, 2, 0): [[[1, -1, -1], [1, 4, -1], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
|
||||
(-1, -1, 2): [[6, 6]], (-1, 0, 1): [[[3, 4, 4], [6, 6, -1]]],
|
||||
(-1, 1, 2): [[[-1, 1, 4], [3, 4, 4], [6, 6, -1]]],
|
||||
(-1, 2, 3): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4], [-1, 6, 6]]],
|
||||
(0, 0, 0): [[3, 4, 4]], (0, 1, 1): [[[1, 4, -1], [3, 4, 4]]],
|
||||
(0, 2, 2): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4]]], (1, 1, 2): [[1, 4]],
|
||||
(1, 2, 3): [[[-1, 1], [1, 4]]]}),
|
||||
([[[6, 1]]], {}),
|
||||
([[[2, 2, 4, 3, 0], [8, 5, 3, 0, 3], [6, 3, 2, 6, 7]]],
|
||||
{(-2, -2, 0): [[6]], (-2, -1, 3): [[[8, 3], [-1, 6]]],
|
||||
(-2, 0, 2): [[[2, 5, 2], [8, 3, -1], [6, -1, -1]]],
|
||||
(-2, 1, 3): [[[2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
|
||||
(-2, 2, 0): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
|
||||
(-2, 3, 1): [
|
||||
[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1], [6, -1, -1]]],
|
||||
(-2, 4, 2): [
|
||||
[[-1, -1, 0], [-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1],
|
||||
[6, -1, -1]]], (-1, -1, 2): [[8, 3]],
|
||||
(-1, 0, 1): [[[2, 5, 2], [8, 3, -1]]],
|
||||
(-1, 1, 2): [[[2, 3, 6], [2, 5, 2], [8, 3, -1]]],
|
||||
(-1, 2, 3): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
|
||||
(-1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
|
||||
(-1, 4, 1): [
|
||||
[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1]]],
|
||||
(0, 0, 0): [[2, 5, 2]], (0, 1, 1): [[[2, 3, 6], [2, 5, 2]]],
|
||||
(0, 2, 2): [[[4, 0, 7], [2, 3, 6], [2, 5, 2]]],
|
||||
(0, 3, 3): [[[-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
|
||||
(0, 4, 0): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
|
||||
(1, 1, 2): [[2, 3, 6]], (1, 2, 3): [[[4, 0, 7], [2, 3, 6]]],
|
||||
(1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6]]],
|
||||
(1, 4, 1): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6]]]}),
|
||||
([[[5], [5]]], {(-1, -1, 2): [[5]], (-1, 0, 1): [[[5], [5]]],
|
||||
(0, 0, 0): [[5]]}),
|
||||
([[[2, 4, 1], [6, 4, 1], [0, 5, 2], [1, 6, 0], [1, 0, 7]]],
|
||||
{(-4, -4, 0): [[1]], (-4, -3, 3): [[[1, 0], [-1, 1]]],
|
||||
(-4, -2, 2): [[[0, 6, 7], [1, 0, -1], [1, -1, -1]]],
|
||||
(-4, -1, 1): [[[6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
|
||||
(-4, 0, 0): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0], [-1, -1, 1]]],
|
||||
(-4, 1, 1): [
|
||||
[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
|
||||
(-4, 2, 2): [
|
||||
[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1],
|
||||
[1, -1, -1]]], (-3, -3, 2): [[1, 0]],
|
||||
(-3, -2, 1): [[[0, 6, 7], [1, 0, -1]]],
|
||||
(-3, -1, 0): [[[6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
|
||||
(-3, 0, 3): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
|
||||
(-3, 1, 0): [[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
|
||||
(-3, 2, 1): [
|
||||
[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1]]],
|
||||
(-2, -2, 0): [[0, 6, 7]], (-2, -1, 3): [[[6, 5, 0], [0, 6, 7]]],
|
||||
(-2, 0, 2): [[[2, 4, 2], [6, 5, 0], [0, 6, 7]]],
|
||||
(-2, 1, 3): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
|
||||
(-2, 2, 0): [[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
|
||||
(-1, -1, 2): [[6, 5, 0]], (-1, 0, 1): [[[2, 4, 2], [6, 5, 0]]],
|
||||
(-1, 1, 2): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
|
||||
(-1, 2, 3): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
|
||||
(0, 0, 0): [[2, 4, 2]], (0, 1, 1): [[[4, 1, -1], [2, 4, 2]]],
|
||||
(0, 2, 2): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2]]], (1, 1, 2): [[4, 1]],
|
||||
(1, 2, 3): [[[-1, 1], [4, 1]]], (2, 2, 0): [[1]]}),
|
||||
([[[6]], [[4]]], {}),
|
||||
([[[2, 4, 8], [3, 4, 2], [1, 6, 3]], [[6, 7, 2], [8, 2, 1], [4, 5, 5]]],
|
||||
{(-2, -2, 0): [[1], [4]], (-2, -1, 3): [[[3, 6], [-1, 1]], [[8, 5], [-1, 4]]],
|
||||
(-2, 0, 2): [[[2, 4, 3], [3, 6, -1], [1, -1, -1]],
|
||||
[[6, 2, 5], [8, 5, -1], [4, -1, -1]]],
|
||||
(-2, 1, 3): [[[-1, 4, 2], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
|
||||
[[-1, 7, 1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
|
||||
(-2, 2, 0): [[[8, -1, -1], [4, 2, -1], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
|
||||
[[2, -1, -1], [7, 1, -1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
|
||||
(-1, -1, 2): [[3, 6], [8, 5]],
|
||||
(-1, 0, 1): [[[2, 4, 3], [3, 6, -1]], [[6, 2, 5], [8, 5, -1]]],
|
||||
(-1, 1, 2): [[[-1, 4, 2], [2, 4, 3], [3, 6, -1]],
|
||||
[[-1, 7, 1], [6, 2, 5], [8, 5, -1]]],
|
||||
(-1, 2, 3): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3], [-1, 3, 6]],
|
||||
[[-1, -1, 2], [-1, 7, 1], [6, 2, 5], [-1, 8, 5]]],
|
||||
(0, 0, 0): [[2, 4, 3], [6, 2, 5]],
|
||||
(0, 1, 1): [[[4, 2, -1], [2, 4, 3]], [[7, 1, -1], [6, 2, 5]]],
|
||||
(0, 2, 2): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3]],
|
||||
[[-1, -1, 2], [-1, 7, 1], [6, 2, 5]]],
|
||||
(1, 1, 2): [[4, 2], [7, 1]],
|
||||
(1, 2, 3): [[[-1, 8], [4, 2]], [[-1, 2], [7, 1]]]}),
|
||||
([[[4, 0]], [[7, 4]]], {}),
|
||||
([[[3, 5, 8, 3, 5], [7, 8, 1, 0, 6], [5, 4, 0, 3, 6]],
|
||||
[[7, 4, 8, 7, 3], [4, 6, 5, 7, 1], [5, 3, 1, 1, 0]]],
|
||||
{(-2, -2, 0): [[5], [5]], (-2, -1, 3): [[[7, 4], [-1, 5]], [[4, 3], [-1, 5]]],
|
||||
(-2, 0, 2): [[[3, 8, 0], [7, 4, -1], [5, -1, -1]],
|
||||
[[7, 6, 1], [4, 3, -1], [5, -1, -1]]],
|
||||
(-2, 1, 3): [[[5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
|
||||
[[4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
|
||||
(-2, 2, 0): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
|
||||
[[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
|
||||
(-2, 3, 1): [
|
||||
[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1], [5, -1, -1]],
|
||||
[[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1], [5, -1, -1]]],
|
||||
(-2, 4, 2): [
|
||||
[[-1, -1, 5], [-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1],
|
||||
[5, -1, -1]],
|
||||
[[-1, -1, 3], [-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1],
|
||||
[5, -1, -1]]],
|
||||
(-1, -1, 2): [[7, 4], [4, 3]],
|
||||
(-1, 0, 1): [[[3, 8, 0], [7, 4, -1]], [[7, 6, 1], [4, 3, -1]]],
|
||||
(-1, 1, 2): [[[5, 1, 3], [3, 8, 0], [7, 4, -1]],
|
||||
[[4, 5, 1], [7, 6, 1], [4, 3, -1]]],
|
||||
(-1, 2, 3): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
|
||||
[[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
|
||||
(-1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
|
||||
[[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
|
||||
(-1, 4, 1): [
|
||||
[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1]],
|
||||
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1]]],
|
||||
(0, 0, 0): [[3, 8, 0], [7, 6, 1]],
|
||||
(0, 1, 1): [[[5, 1, 3], [3, 8, 0]], [[4, 5, 1], [7, 6, 1]]],
|
||||
(0, 2, 2): [[[8, 0, 6], [5, 1, 3], [3, 8, 0]],
|
||||
[[8, 7, 0], [4, 5, 1], [7, 6, 1]]],
|
||||
(0, 3, 3): [[[-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
|
||||
[[-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
|
||||
(0, 4, 0): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
|
||||
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
|
||||
(1, 1, 2): [[5, 1, 3], [4, 5, 1]],
|
||||
(1, 2, 3): [[[8, 0, 6], [5, 1, 3]], [[8, 7, 0], [4, 5, 1]]],
|
||||
(1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3]],
|
||||
[[7, 1, -1], [8, 7, 0], [4, 5, 1]]],
|
||||
(1, 4, 1): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3]],
|
||||
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1]]]}),
|
||||
([[[4], [7]], [[3], [5]]],
|
||||
{(-1, -1, 2): [[7], [5]], (-1, 0, 1): [[[4], [7]], [[3], [5]]],
|
||||
(0, 0, 0): [[4], [3]]}),
|
||||
([[[0, 2, 2], [0, 0, 5], [6, 5, 5], [5, 8, 5], [3, 8, 0]],
|
||||
[[2, 8, 3], [4, 4, 1], [0, 4, 2], [0, 7, 0], [0, 7, 4]]],
|
||||
{(-4, -4, 0): [[3], [0]], (-4, -3, 3): [[[5, 8], [-1, 3]], [[0, 7], [-1, 0]]],
|
||||
(-4, -2, 2): [[[6, 8, 0], [5, 8, -1], [3, -1, -1]],
|
||||
[[0, 7, 4], [0, 7, -1], [0, -1, -1]]],
|
||||
(-4, -1, 1): [[[0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
|
||||
[[4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
|
||||
(-4, 0, 0): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8], [-1, -1, 3]],
|
||||
[[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7], [-1, -1, 0]]],
|
||||
(-4, 1, 1): [
|
||||
[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
|
||||
[[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
|
||||
(-4, 2, 2): [
|
||||
[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1],
|
||||
[3, -1, -1]],
|
||||
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1],
|
||||
[0, -1, -1]]], (-3, -3, 2): [[5, 8], [0, 7]],
|
||||
(-3, -2, 1): [[[6, 8, 0], [5, 8, -1]], [[0, 7, 4], [0, 7, -1]]],
|
||||
(-3, -1, 0): [[[0, 5, 5], [6, 8, 0], [-1, 5, 8]],
|
||||
[[4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
|
||||
(-3, 0, 3): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
|
||||
[[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
|
||||
(-3, 1, 0): [[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
|
||||
[[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
|
||||
(-3, 2, 1): [
|
||||
[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1]],
|
||||
[[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1]]],
|
||||
(-2, -2, 0): [[6, 8, 0], [0, 7, 4]],
|
||||
(-2, -1, 3): [[[0, 5, 5], [6, 8, 0]], [[4, 4, 0], [0, 7, 4]]],
|
||||
(-2, 0, 2): [[[0, 0, 5], [0, 5, 5], [6, 8, 0]],
|
||||
[[2, 4, 2], [4, 4, 0], [0, 7, 4]]],
|
||||
(-2, 1, 3): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
|
||||
[[-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
|
||||
(-2, 2, 0): [[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
|
||||
[[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
|
||||
(-1, -1, 2): [[0, 5, 5], [4, 4, 0]],
|
||||
(-1, 0, 1): [[[0, 0, 5], [0, 5, 5]], [[2, 4, 2], [4, 4, 0]]],
|
||||
(-1, 1, 2): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5]],
|
||||
[[-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
|
||||
(-1, 2, 3): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5]],
|
||||
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
|
||||
(0, 0, 0): [[0, 0, 5], [2, 4, 2]],
|
||||
(0, 1, 1): [[[2, 5, -1], [0, 0, 5]], [[8, 1, -1], [2, 4, 2]]],
|
||||
(0, 2, 2): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5]],
|
||||
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2]]],
|
||||
(1, 1, 2): [[2, 5], [8, 1]],
|
||||
(1, 2, 3): [[[-1, 2], [2, 5]], [[-1, 3], [8, 1]]], (2, 2, 0): [[2], [3]]})])
|
||||
def test_matrix_diag_part_net(array_dict):
|
||||
"""
|
||||
testcase generate from below
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tf.python.ops import array_ops
|
||||
import numpy as np
|
||||
f = open (r'dict.tst','w')
|
||||
aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"}
|
||||
Adict={}
|
||||
Adict=[]
|
||||
for i in [1, 2]:
|
||||
for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]:
|
||||
A = np.array(np.random.randint(20, size=(i, m, n)))
|
||||
Adict[i,m,n]=A
|
||||
p = -1
|
||||
kadict={}
|
||||
for k0 in range(-m + 1, m - 1):
|
||||
for k1 in range(k0, n):
|
||||
|
@ -187,7 +223,7 @@ def test_matrix_diag_part_net_cpu():
|
|||
ka = (k,align_)
|
||||
B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1)
|
||||
kadict[ka] = B.numpy()
|
||||
Adict[i,m,n]=(A, kadict)
|
||||
Adict.append(A, kadict)
|
||||
print(Adict, file= f)
|
||||
f.close()
|
||||
Feature: ALL To ALL
|
||||
|
@ -195,12 +231,11 @@ def test_matrix_diag_part_net_cpu():
|
|||
Expectation: the result match to numpy
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
for _, value in Adict.items():
|
||||
a, kadict = value
|
||||
for key1, b in kadict.items():
|
||||
k0, k1, align_ = key1
|
||||
if k0 == k1:
|
||||
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), k0, PAD_VALUE, align=aligndict[align_])
|
||||
else:
|
||||
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), (k0, k1), PAD_VALUE, align=aligndict[align_])
|
||||
match_matrix(Tensor(b), Tensor(r_b))
|
||||
a, kadict = array_dict
|
||||
for key1, b in kadict.items():
|
||||
k0, k1, align_ = key1
|
||||
if k0 == k1:
|
||||
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), k0, PAD_VALUE, align=aligndict[align_])
|
||||
else:
|
||||
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), (k0, k1), PAD_VALUE, align=aligndict[align_])
|
||||
match_array(b, r_b.asnumpy())
|
||||
|
|
Loading…
Reference in New Issue