!29352 Add matrix_diag_part cpu kernel

Merge pull request !29352 from wuwenbing/master
This commit is contained in:
i-robot 2022-01-22 07:55:23 +00:00 committed by Gitee
commit eb47985854
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 528 additions and 2 deletions

View File

@ -25,7 +25,7 @@ void MatrixBandPartCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
dim_size_ = shapes_.size();
if (shapes_.size() < kDim2) {
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix max than 2.";
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a matrix max than 2.";
}
m_ = shapes_[dim_size_ - kDim2];
n_ = shapes_[dim_size_ - kDim1];
@ -53,7 +53,7 @@ bool MatrixBandPartCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, c
for (size_t k = 0; k < out_range_size_; k++) {
for (size_t i = 0; i < std::min(m_, l + n_); i++) {
const size_t s = i < l ? 0 : i - l;
// when i = n - u, end is n -1, because end pos is start from 0
// When i = n - u, end is n -1, because end pos is start from 0
const size_t e = i >= n_ - u ? n_ - 1 : i + u;
const size_t offset = k * m_ * n_ + i * n_;
memcpy_s(out_value + offset + s, matrix_size_ * sizeof(T), in_value + offset + s, (e - s + 1) * sizeof(T));

View File

@ -0,0 +1,107 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/matrix_diag_part_cpu_kernel.h"
#include <algorithm>
#include <string>
namespace mindspore {
namespace kernel {
template <typename T>
void MatrixDiagPartCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
out_shapes_ = shapes_;
dim_size_ = shapes_.size();
if (shapes_.size() < kDim2) {
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a matrix max than 2.";
}
m_ = shapes_[dim_size_ - kDim2];
n_ = shapes_[dim_size_ - kDim1];
for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
out_range_size_ *= shapes_[i];
}
// Invalid alignment will throw an exception.
auto alignment = AnfAlgo::GetNodeAttr<std::string>(kernel_node, ALIGNMENT);
alignment_ = GetAlignments(alignment);
node_wpt_ = kernel_node;
}
template <typename T>
bool MatrixDiagPartCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *in_value = reinterpret_cast<T *>(inputs[0]->addr);
// K is 2 elements vector, k[0] is lower part, k[0]<0, k[1] is upper part,
int64_t *k_range = reinterpret_cast<int64_t *>(inputs[1]->addr);
T *padding_value = reinterpret_cast<T *>(inputs[2]->addr);
T *out_value = reinterpret_cast<T *>(outputs[0]->addr);
int64_t l = k_range[0];
int64_t u = k_range[1];
// New diagonal matrix m*n matrix, m dimension ;
if (l > u) {
MS_LOG(EXCEPTION) << "The k[1] must not less than k[0].";
}
u = std::min(u, n_ - 1);
l = std::max(-(m_ - 1), l);
int64_t num_diags = u - l + 1;
// New diagonal matrix m * n matrix, n dimension
int64_t max_diag_len =
std::min(m_ + std::min(u, static_cast<int64_t>(0)), n_ + std::min(-l, static_cast<int64_t>(0)));
MS_LOG(DEBUG) << "Num_diags:" << num_diags << ",max_diag_len:" << max_diag_len;
int64_t dest_inner_matrix_size = num_diags * max_diag_len;
out_shapes_ = shapes_;
// Set dynamic shape and dtype
if (!node_wpt_.expired()) {
auto node_ = node_wpt_.lock();
out_shapes_[shapes_.size() - kDim1] = max_diag_len;
// If the out shape m' * n', the m' dimension is 1, then remove this dimension
out_shapes_[shapes_.size() - kDim2] = num_diags;
if (num_diags == 1) {
out_shapes_.erase(out_shapes_.begin() + shapes_.size() - kDim2);
}
auto dtype = AnfAlgo::GetOutputDeviceDataType(node_, 0);
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {out_shapes_}, node_.get());
}
for (int64_t i = 0; i < out_range_size_; i++) {
// The j_index means current dest row index
for (int64_t j = u; j >= l; j--) {
int64_t current_diag_len = j >= 0 ? std::min(n_ - j, m_) : std::min(m_ + j, n_);
int64_t current_pad_len = max_diag_len - current_diag_len;
// Pad left by default
bool pad_left = (alignment_.first == MatrixDiag::Alignment::RIGHT && j > 0) ||
(alignment_.second == MatrixDiag::Alignment::RIGHT && j < 0);
// Set none-padding values, l means current diag col index
for (int64_t k = 0; k < max_diag_len; k++) {
// Source pos, k offset, only effective when pad left
int64_t k_offset = (pad_left && k >= current_pad_len) ? k - current_pad_len : k;
// Calculate source offset row/col offset
size_t row_index = j >= 0 ? j + k_offset : k_offset;
size_t col_index = j >= 0 ? k_offset : k_offset - j;
size_t source_offset = i * m_ * n_ + col_index * n_ + row_index;
// If current pos need pad, then the value is pad value
bool current_pad_flag = (pad_left && k < current_pad_len) || (!pad_left && k >= current_diag_len);
T current_pad_value = current_pad_flag ? *padding_value : *(in_value + source_offset);
int64_t j_index = u - j;
size_t dest_offset = dest_inner_matrix_size * i + j_index * max_diag_len + k;
MS_LOG(DEBUG) << "the diag j:" << j << ",k:" << k << ",k_offset:" << k_offset << ",row:" << row_index
<< ",col:" << col_index << ",j_index:" << j_index << ",current_pad_value:" << current_pad_value;
*(out_value + dest_offset) = current_pad_value;
}
}
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,82 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H
#include <vector>
#include <complex>
#include <utility>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class MatrixDiagPartCPUKernel : public CPUKernel {
public:
MatrixDiagPartCPUKernel() = default;
~MatrixDiagPartCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
// <Super_matrix_diag_align, Sub_matrix_diag_align>
std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT};
std::vector<size_t> shapes_{};
int64_t out_range_size_{1};
size_t dim_size_{1};
int64_t m_{1};
int64_t n_{1};
std::vector<size_t> out_shapes_{};
CNodeWeakPtr node_wpt_;
};
MS_REG_CPU_KERNEL_T(MatrixDiagPart,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MatrixDiagPartCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(MatrixDiagPart,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MatrixDiagPartCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(MatrixDiagPart,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MatrixDiagPartCPUKernel, float);
MS_REG_CPU_KERNEL_T(MatrixDiagPart,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MatrixDiagPartCPUKernel, double);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H

View File

@ -575,6 +575,7 @@ inline const PrimitivePtr kPrimAddcmul = std::make_shared<Primitive>(kAddcmul);
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
inline const PrimitivePtr kPrimMatMulV2 = std::make_shared<Primitive>("MatMulV2");
inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag");
inline const PrimitivePtr kPrimMatrixDiagPart = std::make_shared<Primitive>("MatrixDiagPart");
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
inline const PrimitivePtr kPrimBatchMatMulV2 = std::make_shared<Primitive>("BatchMatMulV2");
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/matrix_diag_part.h"
#include <set>
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "abstract/utils.h"
namespace mindspore {
namespace ops {
namespace {
const constexpr int64_t kShape2 = 2;
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto input_shape = input_args[0]->BuildShape();
auto shape_element = input_shape->cast<abstract::ShapePtr>();
ShapeVector shape = shape_element->shape();
ShapeVector min_shape = shape_element->shape();
ShapeVector max_shape = shape_element->shape();
max_shape[shape.size() - 1] = kShape2 * shape[shape.size() - 1] - 1;
min_shape[shape.size() - 1] = 1;
shape[shape.size() - 1] = abstract::Shape::SHP_ANY;
return std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(infer_type);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("input", infer_type, valid_types, prim->name());
return infer_type;
}
} // namespace
AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPart, prim::kPrimMatrixDiagPart, MatrixDiagPartInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
#define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixDiagPart = "MatrixDiagPart";
/// \brief get the specified part of the inner most diag matrix of a matrix, fill with padding value .
/// Refer to Python API @ref mindspore.ops.MatrixDiagPart for more details.
class MatrixDiagPart : public PrimitiveC {
public:
/// \brief Constructor.
MatrixDiagPart() : PrimitiveC(kNameMatrixDiagPart) { InitIOName({"input", "k", "padding_value"}, {"output"}); }
/// \brief Destructor.
~MatrixDiagPart() = default;
MS_DECLARE_PARENT(MatrixDiagPart, PrimitiveC);
};
AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_

View File

@ -390,4 +390,33 @@ class MatrixBandPartNet(nn.Cell):
return self.matrix_band_part(a, num_lower, num_upper)
class MatrixDiagPart(PrimitiveWithInfer):
"""
MatrixDiagPart()
"""
@prim_attr_register
def __init__(self, align="RIGHT_LEFT"):
super().__init__(name="MatrixDiagPart")
self.add_prim_attr('alignment', align)
self.init_prim_io_names(inputs=['A', 'k', 'padding_value'], outputs=['output'])
class MatrixDiagPartNet(nn.Cell):
"""
Returns:
batched diagonal part of a batched tensor, the part between, k[0] to k[1], the shape is dynamic
Raises:
k[1] should not less then k[0]
"""
def __init__(self, align="RIGHT_LEFT"):
super(MatrixDiagPartNet, self).__init__()
self.matrix_diag_part = MatrixDiagPart(align)
def construct(self, a, k, padding_value):
return self.matrix_diag_part(a, k, padding_value)
from .ops_grad import get_bprpo_eigh

View File

@ -0,0 +1,204 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""st for scipy.ops_wrapper."""
import pytest
from mindspore import context, Tensor
from mindspore.scipy.ops import MatrixDiagPartNet
from tests.st.scipy_st.utils import match_matrix
aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"}
PAD_VALUE = -1
Adict = {(1, 1, 1): (([[[1]]]), {}), (1, 3, 3): (([[[8, 2, 1], [5, 3, 7], [0, 3, 4]]]),
{(-2, -2, 0): ([[0]]), (-2, -1, 3): ([[[5, 3], [-1, 0]]]),
(-2, 0, 2): ([[[8, 3, 4], [5, 3, -1], [0, -1, -1]]]),
(-2, 1, 3): ([[[-1, 2, 7], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]),
(-2, 2, 0): (\
[[[1, -1, -1], [2, 7, -1], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]),
(-1, -1, 2): ([[5, 3]]), (-1, 0, 1): ([[[8, 3, 4], [5, 3, -1]]]),
(-1, 1, 2): ([[[-1, 2, 7], [8, 3, 4], [5, 3, -1]]]),
(-1, 2, 3): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4], [-1, 5, 3]]]),
(0, 0, 0): ([[8, 3, 4]]), (0, 1, 1): ([[[2, 7, -1], [8, 3, 4]]]),
(0, 2, 2): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4]]]),
(1, 1, 2): ([[2, 7]]), (1, 2, 3): ([[[-1, 1], [2, 7]]])}),
(1, 1, 2): (([[[3, 2]]]), {}), (1, 3, 5): (([[[3, 5, 5, 2, 5], [0, 6, 2, 4, 7], [7, 3, 3, 6, 8]]]),
{(-2, -2, 0): ([[7]]), (-2, -1, 3): ([[[0, 3], [-1, 7]]]),
(-2, 0, 2): ([[[3, 6, 3], [0, 3, -1], [7, -1, -1]]]),
(-2, 1, 3): ([[[5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]),
(-2, 2, 0): (\
[[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]),
(-2, 3, 1): ([
[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [0, 3, -1],
[7, -1, -1]]]), (-2, 4, 2): ([\
[[-1, -1, 5], [-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3],
[0, 3, -1], [7, -1, -1]]]), (-1, -1, 2): ([[0, 3]]),
(-1, 0, 1): ([[[3, 6, 3], [0, 3, -1]]]),
(-1, 1, 2): ([[[5, 2, 6], [3, 6, 3], [0, 3, -1]]]),
(-1, 2, 3): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]),
(-1, 3, 0): (\
[[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]),
(-1, 4, 1): ([
[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3],
[0, 3, -1]]]), (0, 0, 0): ([[3, 6, 3]]),
(0, 1, 1): ([[[5, 2, 6], [3, 6, 3]]]),
(0, 2, 2): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
(0, 3, 3): ([[[-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
(0, 4, 0): (\
[[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
(1, 1, 2): ([[5, 2, 6]]), (1, 2, 3): ([[[5, 4, 8], [5, 2, 6]]]),
(1, 3, 0): ([[[2, 7, -1], [5, 4, 8], [5, 2, 6]]]),
(1, 4, 1): ([[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6]]])}),
(1, 2, 1): (([[[4], [1]]]), {(-1, -1, 2): ([[1]]), (-1, 0, 1): ([[[4], [1]]]), (0, 0, 0): ([[4]])}),
(1, 5, 3): (([[[7, 8, 8], [3, 5, 6], [0, 4, 4], [8, 4, 5], [0, 4, 6]]]),
{(-4, -4, 0): ([[0]]), (-4, -3, 3): ([[[8, 4], [-1, 0]]]),
(-4, -2, 2): ([[[0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
(-4, -1, 1): ([[[3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
(-4, 0, 0): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4], [-1, -1, 0]]]),
(-4, 1, 1): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
(-4, 2, 2): (\
[[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
(-3, -3, 2): ([[8, 4]]), (-3, -2, 1): ([[[0, 4, 6], [8, 4, -1]]]),
(-3, -1, 0): ([[[3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
(-3, 0, 3): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
(-3, 1, 0): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
(-3, 2, 1): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1]]]),
(-2, -2, 0): ([[0, 4, 6]]), (-2, -1, 3): ([[[3, 4, 5], [0, 4, 6]]]),
(-2, 0, 2): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
(-2, 1, 3): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
(-2, 2, 0): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
(-1, -1, 2): ([[3, 4, 5]]), (-1, 0, 1): ([[[7, 5, 4], [3, 4, 5]]]),
(-1, 1, 2): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5]]]),
(-1, 2, 3): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5]]]), (0, 0, 0): ([[7, 5, 4]]),
(0, 1, 1): ([[[8, 6, -1], [7, 5, 4]]]), (0, 2, 2): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4]]]),
(1, 1, 2): ([[8, 6]]), (1, 2, 3): ([[[-1, 8], [8, 6]]]), (2, 2, 0): ([[8]])}),
(2, 1, 1): (([[[5]], [[6]]]), {}), (2, 3, 3): (\
([[[7, 6, 3], [5, 8, 5], [5, 0, 2]], [[1, 8, 1], [5, 5, 8], [8, 4, 0]]]),
{(-2, -2, 0): ([[5], [8]]), (-2, -1, 3): ([[[5, 0], [-1, 5]], [[5, 4], [-1, 8]]]),
(-2, 0, 2): ([[[7, 8, 2], [5, 0, -1], [5, -1, -1]], [[1, 5, 0], [5, 4, -1], [8, -1, -1]]]),
(-2, 1, 3): ([[[-1, 6, 5], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]], [[-1, 8, 8], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]),
(-2, 2, 0): ([[[3, -1, -1], [6, 5, -1], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]],
[[1, -1, -1], [8, 8, -1], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]), (-1, -1, 2): ([[5, 0], [5, 4]]),
(-1, 0, 1): ([[[7, 8, 2], [5, 0, -1]], [[1, 5, 0], [5, 4, -1]]]),
(-1, 1, 2): ([[[-1, 6, 5], [7, 8, 2], [5, 0, -1]], [[-1, 8, 8], [1, 5, 0], [5, 4, -1]]]),
(-1, 2, 3): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2], [-1, 5, 0]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0], [-1, 5, 4]]]),
(0, 0, 0): ([[7, 8, 2], [1, 5, 0]]), (0, 1, 1): ([[[6, 5, -1], [7, 8, 2]], [[8, 8, -1], [1, 5, 0]]]),
(0, 2, 2): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0]]]),
(1, 1, 2): ([[6, 5], [8, 8]]), (1, 2, 3): ([[[-1, 3], [6, 5]], [[-1, 1], [8, 8]]])}),
(2, 1, 2): (([[[6, 3]], [[5, 5]]]), {}), (2, 3, 5): (\
([[[1, 2, 1, 2, 7], [0, 3, 5, 0, 2], [0, 5, 1, 7, 5]], [[3, 4, 3, 5, 7], [2, 5, 2, 7, 5], [7, 5, 1, 1, 7]]]),
{(-2, -2, 0): ([[0], [7]]), (-2, -1, 3): ([[[0, 5], [-1, 0]], [[2, 5], [-1, 7]]]),
(-2, 0, 2): ([[[1, 3, 1], [0, 5, -1], [0, -1, -1]], [[3, 5, 1], [2, 5, -1], [7, -1, -1]]]),
(-2, 1, 3): ([[[2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]], [[4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]),
(-2, 2, 0): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]],
[[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]), (-2, 3, 1): (\
[[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]],
[[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]), (-2, 4, 2): (\
[[[-1, -1, 7], [-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]],
[[-1, -1, 7], [-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]),
(-1, -1, 2): ([[0, 5], [2, 5]]), (-1, 0, 1): ([[[1, 3, 1], [0, 5, -1]], [[3, 5, 1], [2, 5, -1]]]),
(-1, 1, 2): ([[[2, 5, 7], [1, 3, 1], [0, 5, -1]], [[4, 2, 1], [3, 5, 1], [2, 5, -1]]]),
(-1, 2, 3): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]], [[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]),
(-1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]],
[[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]), (-1, 4, 1): (\
[[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1]],
[[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1]]]), (0, 0, 0): ([[1, 3, 1], [3, 5, 1]]),
(0, 1, 1): ([[[2, 5, 7], [1, 3, 1]], [[4, 2, 1], [3, 5, 1]]]),
(0, 2, 2): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1]], [[3, 7, 7], [4, 2, 1], [3, 5, 1]]]),
(0, 3, 3): ([[[-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1]], [[-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]),
(0, 4, 0): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1]],
[[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]), (1, 1, 2): ([[2, 5, 7], [4, 2, 1]]),
(1, 2, 3): ([[[1, 0, 5], [2, 5, 7]], [[3, 7, 7], [4, 2, 1]]]),
(1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7]], [[5, 5, -1], [3, 7, 7], [4, 2, 1]]]),
(1, 4, 1): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7]], [[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1]]])}),
(2, 2, 1): (([[[4], [8]], [[3], [5]]]),
{(-1, -1, 2): ([[8], [5]]), (-1, 0, 1): ([[[4], [8]], [[3], [5]]]), (0, 0, 0): ([[4], [3]])}),
(2, 5, 3): (([[[6, 8, 5], [7, 2, 7], [2, 2, 5], [5, 6, 7], [5, 0, 2]],
[[3, 8, 7], [7, 8, 2], [8, 1, 0], [0, 6, 5], [6, 3, 1]]]),
{(-4, -4, 0): ([[5], [6]]), (-4, -3, 3): ([[[5, 0], [-1, 5]], [[0, 3], [-1, 6]]]),
(-4, -2, 2): ([[[2, 6, 2], [5, 0, -1], [5, -1, -1]], [[8, 6, 1], [0, 3, -1], [6, -1, -1]]]),
(-4, -1, 1): ([[[7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
[[7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 0, 0): (\
[[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0], [-1, -1, 5]],
[[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3], [-1, -1, 6]]]), (-4, 1, 1): (\
[[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
[[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 2, 2): (\
[[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
[[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]),
(-3, -3, 2): ([[5, 0], [0, 3]]),
(-3, -2, 1): ([[[2, 6, 2], [5, 0, -1]], [[8, 6, 1], [0, 3, -1]]]),
(-3, -1, 0): ([[[7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[7, 1, 5], [8, 6, 1], [-1, 0, 3]]]),
(-3, 0, 3): (\
[[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]),
(-3, 1, 0): ([[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]],
[[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]), (-3, 2, 1): (\
[[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1]],
[[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1]]]),
(-2, -2, 0): ([[2, 6, 2], [8, 6, 1]]),
(-2, -1, 3): ([[[7, 2, 7], [2, 6, 2]], [[7, 1, 5], [8, 6, 1]]]),
(-2, 0, 2): ([[[6, 2, 5], [7, 2, 7], [2, 6, 2]], [[3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
(-2, 1, 3): (\
[[[-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
(-2, 2, 0): ([[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2]],
[[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
(-1, -1, 2): ([[7, 2, 7], [7, 1, 5]]),
(-1, 0, 1): ([[[6, 2, 5], [7, 2, 7]], [[3, 8, 0], [7, 1, 5]]]),
(-1, 1, 2): ([[[-1, 8, 7], [6, 2, 5], [7, 2, 7]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5]]]),
(-1, 2, 3): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7]],
[[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5]]]),
(0, 0, 0): ([[6, 2, 5], [3, 8, 0]]),
(0, 1, 1): ([[[8, 7, -1], [6, 2, 5]], [[8, 2, -1], [3, 8, 0]]]),
(0, 2, 2): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5]], [[-1, -1, 7], [-1, 8, 2], [3, 8, 0]]]),
(1, 1, 2): ([[8, 7], [8, 2]]), (1, 2, 3): ([[[-1, 5], [8, 7]], [[-1, 7], [8, 2]]]),
(2, 2, 0): ([[5], [7]])})}
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_matrix_diag_part_net_cpu():
"""
testcase generate from below
from tensorflow.python.ops import array_ops
import numpy as np
f = open (r'dict.tst','w')
aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"}
Adict={}
for i in [1, 2]:
for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]:
A = np.array(np.random.randint(20, size=(i, m, n)))
Adict[i,m,n]=A
p = -1
kadict={}
for k0 in range(-m + 1, m - 1):
for k1 in range(k0, n):
k = (k0, k1)
align_= (abs(k0)+ abs(k1)) % 4
ka = (k,align_)
B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1)
kadict[ka] = B.numpy()
Adict[i,m,n]=(A, kadict)
print(Adict, file= f)
f.close()
Feature: ALL To ALL
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
Expectation: the result match to numpy
"""
context.set_context(mode=context.PYNATIVE_MODE)
for _, value in Adict.items():
a, kadict = value
for key1, b in kadict.items():
k0, k1, align_ = key1
msp_matrixdiagpart = MatrixDiagPartNet(align=aligndict[align_])
r_b = msp_matrixdiagpart(Tensor(a), Tensor([k0, k1]), Tensor(PAD_VALUE))
match_matrix(Tensor(b), Tensor(r_b))