!30266 [assistant] [ops] [I4CRJJ] add new array operator Triu

Merge pull request !30266 from ruili/Triu
This commit is contained in:
i-robot 2022-03-30 03:00:48 +00:00 committed by Gitee
commit b67ba47446
10 changed files with 433 additions and 3 deletions


@@ -0,0 +1,140 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Eigen/Dense>
#include <algorithm>
#include "plugin/device/cpu/kernel/triu_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTriuInputsNum = 1;
constexpr size_t kTriuOutputsNum = 1;
constexpr size_t kDim = 2;
} // namespace
void TriuCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
input_dims_ = input_shape_.size();
if (input_dims_ < kDim) {
MS_LOG(EXCEPTION) << "For 'Triu', the rank of the input tensor must be at least 2, but got "
<< input_dims_ << ".";
}
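// 'diagonal_' keeps its default value 0 (the main diagonal) when the node carries no attribute.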
if (common::AnfAlgo::HasNodeAttr("diagonal", kernel_node)) {
diagonal_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "diagonal");
}
}
template <typename T>
bool TriuCpuKernelMod::TriuCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTriuInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTriuOutputsNum, kernel_name_);
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t input_size = 1;
for (size_t i = 0; i < input_dims_; ++i) {
input_size *= input_shape_[i];
}
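// Treat the input as a batch of independent 2-D matrices laid out over the
// two trailing dimensions.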
using MatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
auto matrix_row_num = input_shape_[input_dims_ - 2];
auto matrix_col_num = input_shape_[input_dims_ - 1];
auto matrix_size = matrix_row_num * matrix_col_num;
auto matrix_num = input_size / matrix_size;
for (size_t k = 0; k < matrix_num; ++k) {
MatrixMap input(input_addr + k * matrix_size, matrix_row_num, matrix_col_num);
MatrixMap output(output_addr + k * matrix_size, matrix_row_num, matrix_col_num);
output = input.template triangularView<Eigen::Upper>();
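// triangularView<Eigen::Upper> keeps only the main diagonal and above. A
// negative 'diagonal' must restore the sub-diagonal bands that Triu keeps;
// a positive one must also zero the super-diagonal bands below the
// 'diagonal'-th diagonal.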
if (diagonal_ < 0) {
for (size_t j = 0; j < matrix_col_num; j++) {
for (size_t i = j + 1; i <= j - diagonal_ && i < matrix_row_num; i++) {
output(i, j) = input(i, j);
}
}
} else {
for (size_t i = 0; i < matrix_row_num; i++) {
for (size_t j = i; j < i + diagonal_ && j < matrix_col_num; j++) {
output(i, j) = static_cast<T>(0.0);
}
}
}
}
return true;
}
bool TriuCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
switch (input_dtype_) {
case kNumberTypeUInt8:
return TriuCompute<uint8_t>(inputs, outputs);
case kNumberTypeUInt16:
return TriuCompute<uint16_t>(inputs, outputs);
case kNumberTypeUInt32:
return TriuCompute<uint32_t>(inputs, outputs);
case kNumberTypeUInt64:
return TriuCompute<uint64_t>(inputs, outputs);
case kNumberTypeInt8:
return TriuCompute<int8_t>(inputs, outputs);
case kNumberTypeInt16:
return TriuCompute<int16_t>(inputs, outputs);
case kNumberTypeInt32:
return TriuCompute<int32_t>(inputs, outputs);
case kNumberTypeInt64:
return TriuCompute<int64_t>(inputs, outputs);
case kNumberTypeFloat16:
return TriuCompute<float16>(inputs, outputs);
case kNumberTypeFloat32:
return TriuCompute<float>(inputs, outputs);
case kNumberTypeFloat64:
return TriuCompute<double>(inputs, outputs);
case kNumberTypeBool:
return TriuCompute<bool>(inputs, outputs);
default:
MS_LOG(ERROR) << "For 'Triu', the data type of the input is unsupported.";
return false;
}
}
std::vector<KernelAttr> TriuCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Triu, TriuCpuKernelMod);
} // namespace kernel
} // namespace mindspore
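For reference, the rule the kernel applies to each trailing matrix is "keep
element (i, j) iff j - i >= diagonal". A minimal NumPy sketch of that rule,
checked against np.triu (the helper name triu_reference is illustrative only):

import numpy as np

def triu_reference(x, diagonal=0):
    """Zero the elements below the `diagonal`-th diagonal of each trailing matrix."""
    rows, cols = x.shape[-2], x.shape[-1]
    i = np.arange(rows).reshape(-1, 1)
    j = np.arange(cols).reshape(1, -1)
    mask = (j - i) >= diagonal  # True on and above the chosen diagonal
    return np.where(mask, x, np.zeros_like(x))

x = np.arange(1, 17).reshape(4, 4)
assert np.array_equal(triu_reference(x, 1), np.triu(x, 1))
assert np.array_equal(triu_reference(x, -1), np.triu(x, -1))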


@@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIU_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIU_CPU_KERNEL_H_
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class TriuCpuKernelMod : public NativeCpuKernelMod {
public:
TriuCpuKernelMod() = default;
~TriuCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
int64_t diagonal_{0};
std::vector<size_t> input_shape_;
size_t input_dims_{0};
TypeId input_dtype_{kTypeUnknown};
template <typename T>
bool TriuCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIU_CPU_KERNEL_H_


@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -134,6 +134,7 @@ constexpr auto kLowerBound = "LowerBound";
constexpr auto kUpperBound = "UpperBound";
constexpr auto kCummax = "Cummax";
constexpr auto kTril = "Tril";
constexpr auto kTriu = "Triu";
// NN
constexpr auto kCTCLoss = "CTCLoss";
@@ -376,6 +377,7 @@ GVAR_DEF(PrimitivePtr, kPrimUpperBound, std::make_shared<Primitive>(kUpperBound)
GVAR_DEF(PrimitivePtr, kPrimCummax, std::make_shared<Primitive>(kCummax));
GVAR_DEF(PrimitivePtr, kPrimRightShift, std::make_shared<Primitive>(kRightShift));
GVAR_DEF(PrimitivePtr, kPrimTril, std::make_shared<Primitive>(kTril));
GVAR_DEF(PrimitivePtr, kPrimTriu, std::make_shared<Primitive>(kTriu));
// image
GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared<Primitive>(kCropAndResizeGradBoxes));


@@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/triu.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr TriuInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto x = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
TypePtr TriuInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto x_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32,
kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kBool};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
return x_type;
}
} // namespace
AbstractBasePtr TriuInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputsNum = 1;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
auto infer_type = TriuInferType(primitive, input_args);
auto infer_shape = TriuInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_BASE_IMPL(Triu, PrimitiveC, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(Triu, prim::kPrimTriu, TriuInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
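Both infer helpers are pass-throughs: TriuInferShape returns the input shape
unchanged and TriuInferType returns the input dtype after validating it
against the supported set, so the output always matches the input's shape and
dtype. A hedged usage sketch (assumes a CPU build of MindSpore with this
operator available):

import numpy as np
import mindspore as ms
from mindspore.ops.operations.array_ops import Triu

x = ms.Tensor(np.ones((2, 3, 4), np.float32))
y = Triu(diagonal=1)(x)
assert y.shape == x.shape and y.dtype == x.dtype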

mindspore/core/ops/triu.h

@@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TRIU_H_
#define MINDSPORE_CORE_OPS_TRIU_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameTriu = "Triu";
/// \brief Returns a tensor with the elements below the `diagonal`-th diagonal zeroed.
/// Refer to Python API @ref mindspore.ops.Triu for more details.
class MIND_API Triu : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Triu);
/// \brief Constructor.
Triu() : BaseOperator(kNameTriu) { InitIOName({"x"}, {"y"}); }
/// \brief Init.
void Init() {}
};
AbstractBasePtr TriuInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimTriuPtr = std::shared_ptr<Triu>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TRIU_H_


@@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ from .._grad.grad_math_ops import binop_grad_common
from .._grad.grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.array_ops import Tril
from ..operations.array_ops import Triu
from .. import functional as F
from .. import operations as P
@@ -106,6 +107,19 @@ def get_bprop_coalesce(self):
return bprop
@bprop_getters.register(Triu)
def get_bprop_triu(self):
"""Grad definition for 'Triu' operation"""
diagonal = self.diagonal
triu = Triu(diagonal)
def bprop(x, out, dout):
dx = triu(dout)
return (dx,)
return bprop
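Because Triu is an element-wise mask, entries zeroed in the forward pass
receive no gradient, so the backward pass is the same mask applied to the
incoming gradient: dx = triu(dout) with the same 'diagonal'. An illustrative
NumPy check:

import numpy as np

dout = np.ones((3, 3))
dx = np.triu(dout, k=1)  # gradient flows only where Triu(diagonal=1) kept values
print(dx)
# [[0. 1. 1.]
#  [0. 0. 1.]
#  [0. 0. 0.]]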
@bprop_getters.register(P.SplitV)
def get_bprop_split_v(self):
"""Generate bprop for SplitV"""


@@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -135,3 +135,4 @@ from .priority_replay_buffer import _prb_sample_op_cpu
from .priority_replay_buffer import _prb_update_op_cpu
from .right_shift import _right_shift_aicpu
from .tril import _tril_aicpu
from .triu import _triu_aicpu


@@ -0,0 +1,43 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Triu op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
triu_op_info = AiCPURegOp("Triu") \
.fusion_type("OPAQUE") \
.attr("diagonal", "int") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(triu_op_info)
def _triu_aicpu():
"""Triu aicpu register"""
return


@@ -4803,6 +4803,71 @@ class ScatterSub(_ScatterOpDynamic):
self.add_prim_attr('side_effect_mem', True)
class Triu(Primitive):
"""
Returns a tensor with the elements below the `diagonal`-th diagonal zeroed.
Args:
diagonal (int): The index of the diagonal. Default: 0.
Inputs:
- **x** (Tensor) - The input tensor. The data type is Number. The shape is :math:`(N, *)`,
where :math:`*` means any number of additional dimensions.
Outputs:
- **y** (Tensor) - A tensor with the same shape and data type as the input.
Raises:
TypeError: If `diagonal` is not an int.
TypeError: If `x` is not a Tensor.
ValueError: If the rank of `x` is less than 2.
Supported Platforms:
``CPU``
Examples:
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
... [ 5, 6, 7, 8],
... [10, 11, 12, 13],
... [14, 15, 16, 17]]))
>>> triu = P.Triu()
>>> result = triu(x)
>>> print(result)
[[ 1 2 3 4]
[ 0 6 7 8]
[ 0 0 12 13]
[ 0 0 0 17]]
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
... [ 5, 6, 7, 8],
... [10, 11, 12, 13],
... [14, 15, 16, 17]]))
>>> triu = P.Triu(diagonal=1)
>>> result = triu(x)
>>> print(result)
[[ 0 2 3 4]
[ 0 0 7 8]
[ 0 0 0 13]
[ 0 0 0 0]]
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
... [ 5, 6, 7, 8],
... [10, 11, 12, 13],
... [14, 15, 16, 17]]))
>>> triu = P.Triu(diagonal=-1)
>>> result = triu(x)
>>> print(result)
[[ 1 2 3 4]
[ 5 6 7 8]
[ 0 11 12 13]
[ 0 0 16 17]]
"""
@prim_attr_register
def __init__(self, diagonal=0):
"""Initialize Stack"""
validator.check_value_type("diagonal", diagonal, [int], self.name)
self.diagonal = diagonal
self.init_prim_io_names(inputs=['x'], outputs=['y'])
class ScatterMul(_ScatterOp):
r"""
Updates the value of the input tensor through the multiply operation.


@@ -33,6 +33,7 @@ from mindspore.ops.operations import _quant_ops as Q
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.operations.array_ops import Tril
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.ops.operations.array_ops import Triu
from mindspore.nn.layer import normalization
from mindspore.ops.operations.array_ops import RightShift
from mindspore._c_expression import security
@@ -2835,6 +2836,11 @@ test_case_array_ops = [
'desc_inputs': [Tensor(np.random.rand(3, 8, 9), mstype.float32)],
'desc_bprop': [Tensor(np.random.rand(5, 6, 6), mstype.float32)]
}),
('Triu', {
'block': Triu(),
'desc_inputs': [Tensor(np.random.rand(3, 8, 9), mstype.float32)],
'desc_bprop': [Tensor(np.random.rand(5, 6, 6), mstype.float32)]
}),
]
test_case_image_ops = [