!25339 [assistant][ops] Add new Tril

Merge pull request !25339 from 张凯磊/Tril
i-robot 2022-03-28 09:26:13 +00:00 committed by Gitee
commit c2e0796bdb
10 changed files with 421 additions and 0 deletions


@@ -0,0 +1,154 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/tril_cpu_kernel.h"
#include <algorithm>
#include "Eigen/Core"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTrilInputsNum = 1;
constexpr size_t kTrilOutputsNum = 1;
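// Tril masks the two trailing dimensions, so the input must be at least 2-D.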
constexpr size_t kDim = 2;
} // namespace
void TrilCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
input_dims_ = input_shape_.size();
if (input_dims_ < kDim) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got "
<< input_dims_ << "-D.";
}
if (common::AnfAlgo::HasNodeAttr("diagonal", kernel_node)) {
diagonal_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "diagonal");
}
}
bool TrilCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTrilInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTrilOutputsNum, kernel_name_);
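// Dispatch to the LaunchKernel instantiation matching the dtype captured in InitKernel.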
switch (dtype_) {
case (kNumberTypeUInt8):
LaunchKernel<uint8_t>(inputs, outputs);
break;
case (kNumberTypeUInt16):
LaunchKernel<uint16_t>(inputs, outputs);
break;
case (kNumberTypeUInt32):
LaunchKernel<uint32_t>(inputs, outputs);
break;
case (kNumberTypeUInt64):
LaunchKernel<uint64_t>(inputs, outputs);
break;
case (kNumberTypeInt8):
LaunchKernel<int8_t>(inputs, outputs);
break;
case (kNumberTypeInt16):
LaunchKernel<int16_t>(inputs, outputs);
break;
case (kNumberTypeInt32):
LaunchKernel<int32_t>(inputs, outputs);
break;
case (kNumberTypeInt64):
LaunchKernel<int64_t>(inputs, outputs);
break;
case (kNumberTypeFloat16):
LaunchKernel<float16>(inputs, outputs);
break;
case (kNumberTypeFloat32):
LaunchKernel<float>(inputs, outputs);
break;
case (kNumberTypeFloat64):
LaunchKernel<double>(inputs, outputs);
break;
case (kNumberTypeBool):
LaunchKernel<bool>(inputs, outputs);
break;
default:
MS_LOG(EXCEPTION) << "the datatype of the input not support, support datatype: "
"uint8, uint16, uint32, uint64, int8, int16, int32, int64, "
"float16, float32, float64, bool.";
}
return true;
}
template <typename T>
void TrilCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t input_size = 1;
for (size_t i = 0; i < input_dims_; ++i) {
input_size *= input_shape_[i];
}
using MatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
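// Each trailing 2-D matrix has 'matrix_width' rows and 'matrix_height' columns
// (Eigen's Map takes rows first, then columns).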
auto matrix_width = input_shape_[input_dims_ - 2];
auto matrix_height = input_shape_[input_dims_ - 1];
auto matrix_size = matrix_width * matrix_height;
auto matrix_num = input_size / matrix_size;
for (size_t k = 0; k < matrix_num; ++k) {
MatrixMap input(input_addr + k * matrix_size, matrix_width, matrix_height);
MatrixMap output(output_addr + k * matrix_size, matrix_width, matrix_height);
output = input.template triangularView<Eigen::Lower>();
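// triangularView<Eigen::Lower> keeps the main diagonal and everything below it;
// the branches below adjust the result when the 'diagonal' attribute is nonzero.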
if (diagonal_ > 0) {
for (size_t i = 0; i < matrix_width; i++) {
for (size_t j = i + 1; j <= i + diagonal_ && j < matrix_height; j++) {
output(i, j) = input(i, j);
}
}
} else {
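// For diagonal_ <= 0, additionally zero the first |diagonal_| sub-diagonals. Note that
// 'j - diagonal_' converts diagonal_ to size_t, so for negative values the unsigned
// wrap-around yields j + |diagonal_|, the intended bound (a no-op when diagonal_ == 0).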
for (size_t j = 0; j < matrix_height; j++) {
for (size_t i = j; i < j - diagonal_ && i < matrix_width; i++) {
output(i, j) = static_cast<T>(0.0);
}
}
}
}
}
std::vector<KernelAttr> TrilCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Tril, TrilCpuKernelMod);
} // namespace kernel
} // namespace mindspore


@@ -0,0 +1,53 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
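// CPU kernel for the Tril operator: keeps the elements on and below the 'diagonal'-th
// diagonal of each trailing 2-D matrix of the input and zeroes everything above it.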
class TrilCpuKernelMod : public NativeCpuKernelMod {
public:
TrilCpuKernelMod() = default;
~TrilCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
int64_t diagonal_{0};
std::vector<size_t> input_shape_;
size_t input_dims_{0};
TypeId dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_


@@ -132,6 +132,7 @@ constexpr auto kLstsq = "Lstsq";
constexpr auto kLowerBound = "LowerBound";
constexpr auto kUpperBound = "UpperBound";
constexpr auto kCummax = "Cummax";
constexpr auto kTril = "Tril";
// NN
constexpr auto kCTCLoss = "CTCLoss";
@@ -372,6 +373,7 @@ GVAR_DEF(PrimitivePtr, kPrimLstsq, std::make_shared<Primitive>(kLstsq));
GVAR_DEF(PrimitivePtr, kPrimLowerBound, std::make_shared<Primitive>(kLowerBound));
GVAR_DEF(PrimitivePtr, kPrimUpperBound, std::make_shared<Primitive>(kUpperBound));
GVAR_DEF(PrimitivePtr, kPrimCummax, std::make_shared<Primitive>(kCummax));
GVAR_DEF(PrimitivePtr, kPrimTril, std::make_shared<Primitive>(kTril));
// image
GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared<Primitive>(kCropAndResizeGradBoxes));


@@ -0,0 +1,67 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/tril.h"
#include <algorithm>
#include <set>
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr TrilInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t kShapeSize = 2;
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("x's rank", x_shape.size(), kGreaterEqual, kShapeSize, prim_name);
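// Tril only masks elements, so the output shape is identical to the input shape.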
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr TrilInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
auto x_type = input_args[0]->BuildType();
std::set<TypePtr> valid_x_types(common_valid_types);
(void)valid_x_types.emplace(kBool);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_x_types, prim_name);
return x_type;
}
} // namespace
AbstractBasePtr TrilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
auto infer_type = TrilInferType(primitive, input_args);
auto infer_shape = TrilInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_BASE_IMPL(Tril, PrimitiveC, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(Tril, prim::kPrimTril, TrilInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

mindspore/core/ops/tril.h

@@ -0,0 +1,42 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TRIL_H_
#define MINDSPORE_CORE_OPS_TRIL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameTril = "Tril";
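/// \brief Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices.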
class MIND_API Tril : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Tril);
Tril() : BaseOperator(kNameTril) { InitIOName({"x"}, {"y"}); }
};
abstract::AbstractBasePtr TrilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TRIL_H_


@@ -19,6 +19,7 @@ from ...common import dtype as mstype
from .._grad.grad_math_ops import binop_grad_common
from .._grad.grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.array_ops import Tril
from .. import functional as F
from .. import operations as P
@@ -167,3 +168,16 @@ def get_bprop_extract_volume_patches(self):
return (dx,)
return bprop
@bprop_getters.register(Tril)
def get_bprop_tril(self):
"""Grad definition for 'Tril' operation"""
diagonal = self.diagonal
tril = Tril(diagonal)
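# Tril applies a constant 0/1 mask, so the gradient is the same mask applied to dout.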
def bprop(x, out, dout):
dx = tril(dout)
return (dx,)
return bprop


@@ -133,3 +133,4 @@ from .priority_replay_buffer import _prb_create_op_cpu
from .priority_replay_buffer import _prb_push_op_cpu
from .priority_replay_buffer import _prb_sample_op_cpu
from .priority_replay_buffer import _prb_update_op_cpu
from .tril import _tril_aicpu


@@ -0,0 +1,42 @@
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tril op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
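# AiCPU registration for Tril: an int attribute 'diagonal', one required input 'x',
# one required output 'y', and matching input/output dtype pairs for every supported type.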
tril_op_info = AiCPURegOp("Tril") \
.fusion_type("OPAQUE") \
.attr("diagonal", "int") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()
@op_info_register(tril_op_info)
def _tril_aicpu():
"""Tril AiCPU register"""
return


@@ -7272,3 +7272,43 @@ class Cummax(Primitive):
"""Initialize Cummax"""
validator.check_value_type("dim", dim, [int], self.name)
self.init_prim_io_names(inputs=['x'], outputs=['y', 'indices'])
class Tril(Primitive):
"""
Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices `x`;
the other elements of the result tensor are set to 0.
The lower triangular part of the matrix is defined as the elements on and below the diagonal.
Args:
diagonal (int): An optional attribute indicating the diagonal to consider. Default: 0.
Inputs:
- **x** (Tensor) - A Tensor with shape :math:`(x_1, x_2, ..., x_R)` whose rank is at least 2.
All number types, including bool, are supported.
Outputs:
Tensor, the same shape and data type as the input.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If `diagonal` is not an int.
TypeError: If the type of `x` is neither number nor bool.
ValueError: If the rank of `x` is less than 2.
Supported Platforms:
``CPU``
Examples:
>>> tril = ops.Tril()
>>> output = tril(Tensor(np.array([[-13.5383, 2.5474], [-5.7496, -3.4548]]), mindspore.float32))
>>> print(output)
[[ -13.5383 0. ]
[ -5.7496 -3.4548]]
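>>> # With a positive diagonal, the super-diagonals up to `diagonal` are kept as well.
>>> tril = ops.Tril(diagonal=1)
>>> output = tril(Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32))
>>> print(output)
[[1. 2. 0.]
 [4. 5. 6.]
 [7. 8. 9.]]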
"""
@prim_attr_register
def __init__(self, diagonal=0):
"""Initialize Tril."""
self.init_prim_io_names(inputs=["x"], outputs=["y"])
validator.check_value_type("diagonal", diagonal, [int], self.name)


@@ -31,6 +31,7 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _quant_ops as Q
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.operations.array_ops import Tril
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.nn.layer import normalization
from mindspore._c_expression import security
@@ -2823,6 +2824,11 @@ test_case_array_ops = [
'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])],
'skip': ['backward'],
}),
('Tril', {
'block': Tril(),
'desc_inputs': [Tensor(np.random.rand(3, 8, 9), mstype.float32)],
'desc_bprop': [Tensor(np.random.rand(3, 8, 9), mstype.float32)]
}),
]
test_case_image_ops = [