!29352 Add matrix_diag_part cpu kernel
Merge pull request !29352 from wuwenbing/master
This commit is contained in:
commit
eb47985854
|
@ -25,7 +25,7 @@ void MatrixBandPartCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
dim_size_ = shapes_.size();
|
||||
if (shapes_.size() < kDim2) {
|
||||
MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix max than 2.";
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a matrix max than 2.";
|
||||
}
|
||||
m_ = shapes_[dim_size_ - kDim2];
|
||||
n_ = shapes_[dim_size_ - kDim1];
|
||||
|
@ -53,7 +53,7 @@ bool MatrixBandPartCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, c
|
|||
for (size_t k = 0; k < out_range_size_; k++) {
|
||||
for (size_t i = 0; i < std::min(m_, l + n_); i++) {
|
||||
const size_t s = i < l ? 0 : i - l;
|
||||
// when i = n - u, end is n -1, because end pos is start from 0
|
||||
// When i = n - u, end is n -1, because end pos is start from 0
|
||||
const size_t e = i >= n_ - u ? n_ - 1 : i + u;
|
||||
const size_t offset = k * m_ * n_ + i * n_;
|
||||
memcpy_s(out_value + offset + s, matrix_size_ * sizeof(T), in_value + offset + s, (e - s + 1) * sizeof(T));
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/matrix_diag_part_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void MatrixDiagPartCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
out_shapes_ = shapes_;
|
||||
dim_size_ = shapes_.size();
|
||||
if (shapes_.size() < kDim2) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A should be a matrix max than 2.";
|
||||
}
|
||||
m_ = shapes_[dim_size_ - kDim2];
|
||||
n_ = shapes_[dim_size_ - kDim1];
|
||||
for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
|
||||
out_range_size_ *= shapes_[i];
|
||||
}
|
||||
// Invalid alignment will throw an exception.
|
||||
auto alignment = AnfAlgo::GetNodeAttr<std::string>(kernel_node, ALIGNMENT);
|
||||
alignment_ = GetAlignments(alignment);
|
||||
node_wpt_ = kernel_node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool MatrixDiagPartCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *in_value = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
// K is 2 elements vector, k[0] is lower part, k[0]<0, k[1] is upper part,
|
||||
int64_t *k_range = reinterpret_cast<int64_t *>(inputs[1]->addr);
|
||||
T *padding_value = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
T *out_value = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
int64_t l = k_range[0];
|
||||
int64_t u = k_range[1];
|
||||
// New diagonal matrix m*n matrix, m dimension ;
|
||||
if (l > u) {
|
||||
MS_LOG(EXCEPTION) << "The k[1] must not less than k[0].";
|
||||
}
|
||||
u = std::min(u, n_ - 1);
|
||||
l = std::max(-(m_ - 1), l);
|
||||
int64_t num_diags = u - l + 1;
|
||||
// New diagonal matrix m * n matrix, n dimension
|
||||
int64_t max_diag_len =
|
||||
std::min(m_ + std::min(u, static_cast<int64_t>(0)), n_ + std::min(-l, static_cast<int64_t>(0)));
|
||||
MS_LOG(DEBUG) << "Num_diags:" << num_diags << ",max_diag_len:" << max_diag_len;
|
||||
int64_t dest_inner_matrix_size = num_diags * max_diag_len;
|
||||
out_shapes_ = shapes_;
|
||||
// Set dynamic shape and dtype
|
||||
if (!node_wpt_.expired()) {
|
||||
auto node_ = node_wpt_.lock();
|
||||
out_shapes_[shapes_.size() - kDim1] = max_diag_len;
|
||||
// If the out shape m' * n', the m' dimension is 1, then remove this dimension
|
||||
out_shapes_[shapes_.size() - kDim2] = num_diags;
|
||||
if (num_diags == 1) {
|
||||
out_shapes_.erase(out_shapes_.begin() + shapes_.size() - kDim2);
|
||||
}
|
||||
auto dtype = AnfAlgo::GetOutputDeviceDataType(node_, 0);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {out_shapes_}, node_.get());
|
||||
}
|
||||
for (int64_t i = 0; i < out_range_size_; i++) {
|
||||
// The j_index means current dest row index
|
||||
for (int64_t j = u; j >= l; j--) {
|
||||
int64_t current_diag_len = j >= 0 ? std::min(n_ - j, m_) : std::min(m_ + j, n_);
|
||||
int64_t current_pad_len = max_diag_len - current_diag_len;
|
||||
// Pad left by default
|
||||
bool pad_left = (alignment_.first == MatrixDiag::Alignment::RIGHT && j > 0) ||
|
||||
(alignment_.second == MatrixDiag::Alignment::RIGHT && j < 0);
|
||||
// Set none-padding values, l means current diag col index
|
||||
for (int64_t k = 0; k < max_diag_len; k++) {
|
||||
// Source pos, k offset, only effective when pad left
|
||||
int64_t k_offset = (pad_left && k >= current_pad_len) ? k - current_pad_len : k;
|
||||
// Calculate source offset row/col offset
|
||||
size_t row_index = j >= 0 ? j + k_offset : k_offset;
|
||||
size_t col_index = j >= 0 ? k_offset : k_offset - j;
|
||||
size_t source_offset = i * m_ * n_ + col_index * n_ + row_index;
|
||||
// If current pos need pad, then the value is pad value
|
||||
bool current_pad_flag = (pad_left && k < current_pad_len) || (!pad_left && k >= current_diag_len);
|
||||
T current_pad_value = current_pad_flag ? *padding_value : *(in_value + source_offset);
|
||||
int64_t j_index = u - j;
|
||||
size_t dest_offset = dest_inner_matrix_size * i + j_index * max_diag_len + k;
|
||||
MS_LOG(DEBUG) << "the diag j:" << j << ",k:" << k << ",k_offset:" << k_offset << ",row:" << row_index
|
||||
<< ",col:" << col_index << ",j_index:" << j_index << ",current_pad_value:" << current_pad_value;
|
||||
*(out_value + dest_offset) = current_pad_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H
|
||||
#include <vector>
|
||||
#include <complex>
|
||||
#include <utility>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class MatrixDiagPartCPUKernel : public CPUKernel {
|
||||
public:
|
||||
MatrixDiagPartCPUKernel() = default;
|
||||
~MatrixDiagPartCPUKernel() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
// <Super_matrix_diag_align, Sub_matrix_diag_align>
|
||||
std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT};
|
||||
std::vector<size_t> shapes_{};
|
||||
int64_t out_range_size_{1};
|
||||
size_t dim_size_{1};
|
||||
int64_t m_{1};
|
||||
int64_t n_{1};
|
||||
std::vector<size_t> out_shapes_{};
|
||||
CNodeWeakPtr node_wpt_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MatrixDiagPart,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
MatrixDiagPartCPUKernel, int32_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MatrixDiagPart,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
MatrixDiagPartCPUKernel, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MatrixDiagPart,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
|
||||
MatrixDiagPartCPUKernel, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(MatrixDiagPart,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
MatrixDiagPartCPUKernel, double);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H
|
|
@ -575,6 +575,7 @@ inline const PrimitivePtr kPrimAddcmul = std::make_shared<Primitive>(kAddcmul);
|
|||
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
|
||||
inline const PrimitivePtr kPrimMatMulV2 = std::make_shared<Primitive>("MatMulV2");
|
||||
inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag");
|
||||
inline const PrimitivePtr kPrimMatrixDiagPart = std::make_shared<Primitive>("MatrixDiagPart");
|
||||
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
|
||||
inline const PrimitivePtr kPrimBatchMatMulV2 = std::make_shared<Primitive>("BatchMatMulV2");
|
||||
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/matrix_diag_part.h"
|
||||
#include <set>
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const constexpr int64_t kShape2 = 2;
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_shape = input_args[0]->BuildShape();
|
||||
auto shape_element = input_shape->cast<abstract::ShapePtr>();
|
||||
ShapeVector shape = shape_element->shape();
|
||||
ShapeVector min_shape = shape_element->shape();
|
||||
ShapeVector max_shape = shape_element->shape();
|
||||
max_shape[shape.size() - 1] = kShape2 * shape[shape.size() - 1] - 1;
|
||||
min_shape[shape.size() - 1] = 1;
|
||||
shape[shape.size() - 1] = abstract::Shape::SHP_ANY;
|
||||
return std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(infer_type);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32, kInt64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("input", infer_type, valid_types, prim->name());
|
||||
return infer_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPart, prim::kPrimMatrixDiagPart, MatrixDiagPartInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
|
||||
#define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMatrixDiagPart = "MatrixDiagPart";
|
||||
/// \brief get the specified part of the inner most diag matrix of a matrix, fill with padding value .
|
||||
/// Refer to Python API @ref mindspore.ops.MatrixDiagPart for more details.
|
||||
class MatrixDiagPart : public PrimitiveC {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
MatrixDiagPart() : PrimitiveC(kNameMatrixDiagPart) { InitIOName({"input", "k", "padding_value"}, {"output"}); }
|
||||
/// \brief Destructor.
|
||||
~MatrixDiagPart() = default;
|
||||
MS_DECLARE_PARENT(MatrixDiagPart, PrimitiveC);
|
||||
};
|
||||
|
||||
AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
|
|
@ -390,4 +390,33 @@ class MatrixBandPartNet(nn.Cell):
|
|||
return self.matrix_band_part(a, num_lower, num_upper)
|
||||
|
||||
|
||||
class MatrixDiagPart(PrimitiveWithInfer):
|
||||
"""
|
||||
MatrixDiagPart()
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, align="RIGHT_LEFT"):
|
||||
super().__init__(name="MatrixDiagPart")
|
||||
self.add_prim_attr('alignment', align)
|
||||
self.init_prim_io_names(inputs=['A', 'k', 'padding_value'], outputs=['output'])
|
||||
|
||||
|
||||
class MatrixDiagPartNet(nn.Cell):
|
||||
"""
|
||||
Returns:
|
||||
batched diagonal part of a batched tensor, the part between, k[0] to k[1], the shape is dynamic
|
||||
|
||||
Raises:
|
||||
k[1] should not less then k[0]
|
||||
"""
|
||||
|
||||
def __init__(self, align="RIGHT_LEFT"):
|
||||
super(MatrixDiagPartNet, self).__init__()
|
||||
self.matrix_diag_part = MatrixDiagPart(align)
|
||||
|
||||
def construct(self, a, k, padding_value):
|
||||
return self.matrix_diag_part(a, k, padding_value)
|
||||
|
||||
|
||||
from .ops_grad import get_bprpo_eigh
|
||||
|
|
|
@ -0,0 +1,204 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""st for scipy.ops_wrapper."""
|
||||
import pytest
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.scipy.ops import MatrixDiagPartNet
|
||||
from tests.st.scipy_st.utils import match_matrix
|
||||
|
||||
aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"}
|
||||
PAD_VALUE = -1
|
||||
Adict = {(1, 1, 1): (([[[1]]]), {}), (1, 3, 3): (([[[8, 2, 1], [5, 3, 7], [0, 3, 4]]]),
|
||||
{(-2, -2, 0): ([[0]]), (-2, -1, 3): ([[[5, 3], [-1, 0]]]),
|
||||
(-2, 0, 2): ([[[8, 3, 4], [5, 3, -1], [0, -1, -1]]]),
|
||||
(-2, 1, 3): ([[[-1, 2, 7], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]),
|
||||
(-2, 2, 0): (\
|
||||
[[[1, -1, -1], [2, 7, -1], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]),
|
||||
(-1, -1, 2): ([[5, 3]]), (-1, 0, 1): ([[[8, 3, 4], [5, 3, -1]]]),
|
||||
(-1, 1, 2): ([[[-1, 2, 7], [8, 3, 4], [5, 3, -1]]]),
|
||||
(-1, 2, 3): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4], [-1, 5, 3]]]),
|
||||
(0, 0, 0): ([[8, 3, 4]]), (0, 1, 1): ([[[2, 7, -1], [8, 3, 4]]]),
|
||||
(0, 2, 2): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4]]]),
|
||||
(1, 1, 2): ([[2, 7]]), (1, 2, 3): ([[[-1, 1], [2, 7]]])}),
|
||||
(1, 1, 2): (([[[3, 2]]]), {}), (1, 3, 5): (([[[3, 5, 5, 2, 5], [0, 6, 2, 4, 7], [7, 3, 3, 6, 8]]]),
|
||||
{(-2, -2, 0): ([[7]]), (-2, -1, 3): ([[[0, 3], [-1, 7]]]),
|
||||
(-2, 0, 2): ([[[3, 6, 3], [0, 3, -1], [7, -1, -1]]]),
|
||||
(-2, 1, 3): ([[[5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]),
|
||||
(-2, 2, 0): (\
|
||||
[[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]),
|
||||
(-2, 3, 1): ([
|
||||
[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [0, 3, -1],
|
||||
[7, -1, -1]]]), (-2, 4, 2): ([\
|
||||
[[-1, -1, 5], [-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3],
|
||||
[0, 3, -1], [7, -1, -1]]]), (-1, -1, 2): ([[0, 3]]),
|
||||
(-1, 0, 1): ([[[3, 6, 3], [0, 3, -1]]]),
|
||||
(-1, 1, 2): ([[[5, 2, 6], [3, 6, 3], [0, 3, -1]]]),
|
||||
(-1, 2, 3): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]),
|
||||
(-1, 3, 0): (\
|
||||
[[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]),
|
||||
(-1, 4, 1): ([
|
||||
[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3],
|
||||
[0, 3, -1]]]), (0, 0, 0): ([[3, 6, 3]]),
|
||||
(0, 1, 1): ([[[5, 2, 6], [3, 6, 3]]]),
|
||||
(0, 2, 2): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
|
||||
(0, 3, 3): ([[[-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
|
||||
(0, 4, 0): (\
|
||||
[[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]),
|
||||
(1, 1, 2): ([[5, 2, 6]]), (1, 2, 3): ([[[5, 4, 8], [5, 2, 6]]]),
|
||||
(1, 3, 0): ([[[2, 7, -1], [5, 4, 8], [5, 2, 6]]]),
|
||||
(1, 4, 1): ([[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6]]])}),
|
||||
(1, 2, 1): (([[[4], [1]]]), {(-1, -1, 2): ([[1]]), (-1, 0, 1): ([[[4], [1]]]), (0, 0, 0): ([[4]])}),
|
||||
(1, 5, 3): (([[[7, 8, 8], [3, 5, 6], [0, 4, 4], [8, 4, 5], [0, 4, 6]]]),
|
||||
{(-4, -4, 0): ([[0]]), (-4, -3, 3): ([[[8, 4], [-1, 0]]]),
|
||||
(-4, -2, 2): ([[[0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
|
||||
(-4, -1, 1): ([[[3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
|
||||
(-4, 0, 0): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4], [-1, -1, 0]]]),
|
||||
(-4, 1, 1): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
|
||||
(-4, 2, 2): (\
|
||||
[[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]),
|
||||
(-3, -3, 2): ([[8, 4]]), (-3, -2, 1): ([[[0, 4, 6], [8, 4, -1]]]),
|
||||
(-3, -1, 0): ([[[3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
|
||||
(-3, 0, 3): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
|
||||
(-3, 1, 0): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]),
|
||||
(-3, 2, 1): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1]]]),
|
||||
(-2, -2, 0): ([[0, 4, 6]]), (-2, -1, 3): ([[[3, 4, 5], [0, 4, 6]]]),
|
||||
(-2, 0, 2): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
|
||||
(-2, 1, 3): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
|
||||
(-2, 2, 0): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]),
|
||||
(-1, -1, 2): ([[3, 4, 5]]), (-1, 0, 1): ([[[7, 5, 4], [3, 4, 5]]]),
|
||||
(-1, 1, 2): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5]]]),
|
||||
(-1, 2, 3): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5]]]), (0, 0, 0): ([[7, 5, 4]]),
|
||||
(0, 1, 1): ([[[8, 6, -1], [7, 5, 4]]]), (0, 2, 2): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4]]]),
|
||||
(1, 1, 2): ([[8, 6]]), (1, 2, 3): ([[[-1, 8], [8, 6]]]), (2, 2, 0): ([[8]])}),
|
||||
(2, 1, 1): (([[[5]], [[6]]]), {}), (2, 3, 3): (\
|
||||
([[[7, 6, 3], [5, 8, 5], [5, 0, 2]], [[1, 8, 1], [5, 5, 8], [8, 4, 0]]]),
|
||||
{(-2, -2, 0): ([[5], [8]]), (-2, -1, 3): ([[[5, 0], [-1, 5]], [[5, 4], [-1, 8]]]),
|
||||
(-2, 0, 2): ([[[7, 8, 2], [5, 0, -1], [5, -1, -1]], [[1, 5, 0], [5, 4, -1], [8, -1, -1]]]),
|
||||
(-2, 1, 3): ([[[-1, 6, 5], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]], [[-1, 8, 8], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]),
|
||||
(-2, 2, 0): ([[[3, -1, -1], [6, 5, -1], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]],
|
||||
[[1, -1, -1], [8, 8, -1], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]), (-1, -1, 2): ([[5, 0], [5, 4]]),
|
||||
(-1, 0, 1): ([[[7, 8, 2], [5, 0, -1]], [[1, 5, 0], [5, 4, -1]]]),
|
||||
(-1, 1, 2): ([[[-1, 6, 5], [7, 8, 2], [5, 0, -1]], [[-1, 8, 8], [1, 5, 0], [5, 4, -1]]]),
|
||||
(-1, 2, 3): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2], [-1, 5, 0]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0], [-1, 5, 4]]]),
|
||||
(0, 0, 0): ([[7, 8, 2], [1, 5, 0]]), (0, 1, 1): ([[[6, 5, -1], [7, 8, 2]], [[8, 8, -1], [1, 5, 0]]]),
|
||||
(0, 2, 2): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0]]]),
|
||||
(1, 1, 2): ([[6, 5], [8, 8]]), (1, 2, 3): ([[[-1, 3], [6, 5]], [[-1, 1], [8, 8]]])}),
|
||||
(2, 1, 2): (([[[6, 3]], [[5, 5]]]), {}), (2, 3, 5): (\
|
||||
([[[1, 2, 1, 2, 7], [0, 3, 5, 0, 2], [0, 5, 1, 7, 5]], [[3, 4, 3, 5, 7], [2, 5, 2, 7, 5], [7, 5, 1, 1, 7]]]),
|
||||
{(-2, -2, 0): ([[0], [7]]), (-2, -1, 3): ([[[0, 5], [-1, 0]], [[2, 5], [-1, 7]]]),
|
||||
(-2, 0, 2): ([[[1, 3, 1], [0, 5, -1], [0, -1, -1]], [[3, 5, 1], [2, 5, -1], [7, -1, -1]]]),
|
||||
(-2, 1, 3): ([[[2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]], [[4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]),
|
||||
(-2, 2, 0): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]],
|
||||
[[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]), (-2, 3, 1): (\
|
||||
[[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]],
|
||||
[[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]), (-2, 4, 2): (\
|
||||
[[[-1, -1, 7], [-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]],
|
||||
[[-1, -1, 7], [-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]),
|
||||
(-1, -1, 2): ([[0, 5], [2, 5]]), (-1, 0, 1): ([[[1, 3, 1], [0, 5, -1]], [[3, 5, 1], [2, 5, -1]]]),
|
||||
(-1, 1, 2): ([[[2, 5, 7], [1, 3, 1], [0, 5, -1]], [[4, 2, 1], [3, 5, 1], [2, 5, -1]]]),
|
||||
(-1, 2, 3): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]], [[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]),
|
||||
(-1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]],
|
||||
[[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]), (-1, 4, 1): (\
|
||||
[[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1]],
|
||||
[[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1]]]), (0, 0, 0): ([[1, 3, 1], [3, 5, 1]]),
|
||||
(0, 1, 1): ([[[2, 5, 7], [1, 3, 1]], [[4, 2, 1], [3, 5, 1]]]),
|
||||
(0, 2, 2): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1]], [[3, 7, 7], [4, 2, 1], [3, 5, 1]]]),
|
||||
(0, 3, 3): ([[[-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1]], [[-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]),
|
||||
(0, 4, 0): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1]],
|
||||
[[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]), (1, 1, 2): ([[2, 5, 7], [4, 2, 1]]),
|
||||
(1, 2, 3): ([[[1, 0, 5], [2, 5, 7]], [[3, 7, 7], [4, 2, 1]]]),
|
||||
(1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7]], [[5, 5, -1], [3, 7, 7], [4, 2, 1]]]),
|
||||
(1, 4, 1): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7]], [[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1]]])}),
|
||||
(2, 2, 1): (([[[4], [8]], [[3], [5]]]),
|
||||
{(-1, -1, 2): ([[8], [5]]), (-1, 0, 1): ([[[4], [8]], [[3], [5]]]), (0, 0, 0): ([[4], [3]])}),
|
||||
(2, 5, 3): (([[[6, 8, 5], [7, 2, 7], [2, 2, 5], [5, 6, 7], [5, 0, 2]],
|
||||
[[3, 8, 7], [7, 8, 2], [8, 1, 0], [0, 6, 5], [6, 3, 1]]]),
|
||||
{(-4, -4, 0): ([[5], [6]]), (-4, -3, 3): ([[[5, 0], [-1, 5]], [[0, 3], [-1, 6]]]),
|
||||
(-4, -2, 2): ([[[2, 6, 2], [5, 0, -1], [5, -1, -1]], [[8, 6, 1], [0, 3, -1], [6, -1, -1]]]),
|
||||
(-4, -1, 1): ([[[7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
|
||||
[[7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 0, 0): (\
|
||||
[[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0], [-1, -1, 5]],
|
||||
[[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3], [-1, -1, 6]]]), (-4, 1, 1): (\
|
||||
[[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
|
||||
[[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 2, 2): (\
|
||||
[[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]],
|
||||
[[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]),
|
||||
(-3, -3, 2): ([[5, 0], [0, 3]]),
|
||||
(-3, -2, 1): ([[[2, 6, 2], [5, 0, -1]], [[8, 6, 1], [0, 3, -1]]]),
|
||||
(-3, -1, 0): ([[[7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[7, 1, 5], [8, 6, 1], [-1, 0, 3]]]),
|
||||
(-3, 0, 3): (\
|
||||
[[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]),
|
||||
(-3, 1, 0): ([[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]],
|
||||
[[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]), (-3, 2, 1): (\
|
||||
[[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1]],
|
||||
[[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1]]]),
|
||||
(-2, -2, 0): ([[2, 6, 2], [8, 6, 1]]),
|
||||
(-2, -1, 3): ([[[7, 2, 7], [2, 6, 2]], [[7, 1, 5], [8, 6, 1]]]),
|
||||
(-2, 0, 2): ([[[6, 2, 5], [7, 2, 7], [2, 6, 2]], [[3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
|
||||
(-2, 1, 3): (\
|
||||
[[[-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
|
||||
(-2, 2, 0): ([[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2]],
|
||||
[[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]),
|
||||
(-1, -1, 2): ([[7, 2, 7], [7, 1, 5]]),
|
||||
(-1, 0, 1): ([[[6, 2, 5], [7, 2, 7]], [[3, 8, 0], [7, 1, 5]]]),
|
||||
(-1, 1, 2): ([[[-1, 8, 7], [6, 2, 5], [7, 2, 7]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5]]]),
|
||||
(-1, 2, 3): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7]],
|
||||
[[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5]]]),
|
||||
(0, 0, 0): ([[6, 2, 5], [3, 8, 0]]),
|
||||
(0, 1, 1): ([[[8, 7, -1], [6, 2, 5]], [[8, 2, -1], [3, 8, 0]]]),
|
||||
(0, 2, 2): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5]], [[-1, -1, 7], [-1, 8, 2], [3, 8, 0]]]),
|
||||
(1, 1, 2): ([[8, 7], [8, 2]]), (1, 2, 3): ([[[-1, 5], [8, 7]], [[-1, 7], [8, 2]]]),
|
||||
(2, 2, 0): ([[5], [7]])})}
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_matrix_diag_part_net_cpu():
|
||||
"""
|
||||
testcase generate from below
|
||||
from tensorflow.python.ops import array_ops
|
||||
import numpy as np
|
||||
f = open (r'dict.tst','w')
|
||||
aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"}
|
||||
Adict={}
|
||||
for i in [1, 2]:
|
||||
for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]:
|
||||
A = np.array(np.random.randint(20, size=(i, m, n)))
|
||||
Adict[i,m,n]=A
|
||||
p = -1
|
||||
kadict={}
|
||||
for k0 in range(-m + 1, m - 1):
|
||||
for k1 in range(k0, n):
|
||||
k = (k0, k1)
|
||||
align_= (abs(k0)+ abs(k1)) % 4
|
||||
ka = (k,align_)
|
||||
B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1)
|
||||
kadict[ka] = B.numpy()
|
||||
Adict[i,m,n]=(A, kadict)
|
||||
print(Adict, file= f)
|
||||
f.close()
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
for _, value in Adict.items():
|
||||
a, kadict = value
|
||||
for key1, b in kadict.items():
|
||||
k0, k1, align_ = key1
|
||||
msp_matrixdiagpart = MatrixDiagPartNet(align=aligndict[align_])
|
||||
r_b = msp_matrixdiagpart(Tensor(a), Tensor([k0, k1]), Tensor(PAD_VALUE))
|
||||
match_matrix(Tensor(b), Tensor(r_b))
|
Loading…
Reference in New Issue