diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/tril_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/tril_cpu_kernel.cc new file mode 100644 index 00000000000..09b2ef145b5 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/tril_cpu_kernel.cc @@ -0,0 +1,154 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/tril_cpu_kernel.h" + +#include +#include "Eigen/Core" + +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kTrilInputsNum = 1; +constexpr size_t kTrilOutputsNum = 1; +constexpr size_t kDim = 2; +} // namespace + +void TrilCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); + input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + + input_dims_ = input_shape_.size(); + if (input_dims_ < kDim) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got " + << input_dims_ << "-D."; + } + if (common::AnfAlgo::HasNodeAttr("diagonal", kernel_node)) { + diagonal_ = common::AnfAlgo::GetNodeAttr(kernel_node, "diagonal"); + } +} + +bool TrilCpuKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTrilInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTrilOutputsNum, kernel_name_); + + switch (dtype_) { + case (kNumberTypeUInt8): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeUInt16): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeUInt32): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeUInt64): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeInt8): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeInt16): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeInt32): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeInt64): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeFloat16): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeFloat32): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeFloat64): + LaunchKernel(inputs, outputs); + break; + case (kNumberTypeBool): + LaunchKernel(inputs, outputs); + break; + default: + MS_LOG(EXCEPTION) << "the datatype of the input not support, support datatype: " + "uint8, uint16, uint32, uint64, int8, int16, int32, int64, " + "float16, float32, float64, bool."; + } + return true; +} + +template +void TrilCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + size_t input_size = 1; + for (size_t i = 0; i < input_dims_; ++i) { + input_size *= input_shape_[i]; + } + + using MatrixMap = Eigen::Map>; + + auto matrix_width = input_shape_[input_dims_ - 2]; + auto matrix_height = input_shape_[input_dims_ - 1]; + auto matrix_size = matrix_width * matrix_height; + auto matrixs_num = input_size / matrix_size; + + for (size_t k = 0; k < matrixs_num; ++k) { + MatrixMap input(input_addr + k * matrix_size, matrix_width, matrix_height); + MatrixMap output(output_addr + k * matrix_size, matrix_width, matrix_height); + output = input.template triangularView(); + if (diagonal_ > 0) { + for (size_t i = 0; i < matrix_width; i++) { + for (size_t j = i + 1; j <= i + diagonal_ && j < matrix_height; j++) { + output(i, j) = input(i, j); + } + } + } else { + for (size_t j = 0; j < matrix_height; j++) { + for (size_t i = j; i < j - diagonal_ && i < matrix_width; i++) { + output(i, j) = static_cast(0.0); + } + } + } + } +} + +std::vector TrilCpuKernelMod::GetOpSupport() { + static std::vector kernel_attr_list = { + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool)}; + return kernel_attr_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Tril, TrilCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/tril_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/tril_cpu_kernel.h new file mode 100644 index 00000000000..ac2cbc07fb9 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/tril_cpu_kernel.h @@ -0,0 +1,53 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_ + +#include +#include + +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class TrilCpuKernelMod : public NativeCpuKernelMod { + public: + TrilCpuKernelMod() = default; + ~TrilCpuKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + protected: + std::vector GetOpSupport() override; + + private: + int64_t diagonal_{0}; + std::vector input_shape_; + size_t input_dims_; + TypeId dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 4f9d7ed9bf1..53bea37b776 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -132,6 +132,7 @@ constexpr auto kLstsq = "Lstsq"; constexpr auto kLowerBound = "LowerBound"; constexpr auto kUpperBound = "UpperBound"; constexpr auto kCummax = "Cummax"; +constexpr auto kTril = "Tril"; // NN constexpr auto kCTCLoss = "CTCLoss"; @@ -372,6 +373,7 @@ GVAR_DEF(PrimitivePtr, kPrimLstsq, std::make_shared(kLstsq)); GVAR_DEF(PrimitivePtr, kPrimLowerBound, std::make_shared(kLowerBound)); GVAR_DEF(PrimitivePtr, kPrimUpperBound, std::make_shared(kUpperBound)); GVAR_DEF(PrimitivePtr, kPrimCummax, std::make_shared(kCummax)); +GVAR_DEF(PrimitivePtr, kPrimTril, std::make_shared(kTril)); // image GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared(kCropAndResizeGradBoxes)); diff --git a/mindspore/core/ops/tril.cc b/mindspore/core/ops/tril.cc new file mode 100644 index 00000000000..875a0937b1a --- /dev/null +++ b/mindspore/core/ops/tril.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/tril.h" + +#include +#include + +#include "abstract/primitive_infer_map.h" +#include "ops/op_utils.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr TrilInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + + const int64_t kShapeSize = 2; + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("x's rank", x_shape.size(), kGreaterEqual, kShapeSize, prim_name); + + return std::make_shared(x_shape); +} + +TypePtr TrilInferType(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + + MS_EXCEPTION_IF_NULL(input_args[0]); + auto x_type = input_args[0]->BuildType(); + std::set valid_x_types(common_valid_types); + (void)valid_x_types.emplace(kBool); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_x_types, prim_name); + return x_type; +} +} // namespace + +AbstractBasePtr TrilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const int64_t kInputNum = 1; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name()); + + auto infer_type = TrilInferType(primitive, input_args); + auto infer_shape = TrilInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} + +MIND_API_BASE_IMPL(Tril, PrimitiveC, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(Tril, prim::kPrimTril, TrilInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/tril.h b/mindspore/core/ops/tril.h new file mode 100644 index 00000000000..aaebf4208da --- /dev/null +++ b/mindspore/core/ops/tril.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_TRIL_H_ +#define MINDSPORE_CORE_OPS_TRIL_H_ + +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameTril = "Tril"; +class MIND_API Tril : public BaseOperator { + public: + MIND_API_BASE_MEMBER(Tril); + Tril() : BaseOperator(kNameTril) { InitIOName({"x"}, {"y"}); } +}; + +abstract::AbstractBasePtr TrilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_TRIL_H_ diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py index 34636780ee2..3cdf1d6de7d 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py @@ -19,6 +19,7 @@ from ...common import dtype as mstype from .._grad.grad_math_ops import binop_grad_common from .._grad.grad_base import bprop_getters from ..composite.multitype_ops.zeros_like_impl import zeros_like +from ..operations.array_ops import Tril from .. import functional as F from .. import operations as P @@ -167,3 +168,16 @@ def get_bprop_extract_volume_patches(self): return (dx,) return bprop + + +@bprop_getters.register(Tril) +def get_bprop_tril(self): + """Grad definition for 'Tril' operation""" + diagonal = self.diagonal + tril = Tril(diagonal) + + def bprop(x, out, dout): + dx = tril(dout) + return (dx,) + + return bprop diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index 2c055696f14..c0caf2061e5 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -133,3 +133,4 @@ from .priority_replay_buffer import _prb_create_op_cpu from .priority_replay_buffer import _prb_push_op_cpu from .priority_replay_buffer import _prb_sample_op_cpu from .priority_replay_buffer import _prb_update_op_cpu +from .tril import _tril_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/tril.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/tril.py new file mode 100644 index 00000000000..9fb5fdaf9e2 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/tril.py @@ -0,0 +1,42 @@ +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tril op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +tril_op_info = AiCPURegOp("Tril") \ + .fusion_type("OPAQUE") \ + .attr("diagonal", "int") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(tril_op_info) +def _tril_aicpu(): + """Tril AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 4b62bfdd84c..19bb375539a 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -7272,3 +7272,43 @@ class Cummax(Primitive): """Initialize Cummax""" validator.check_value_type("dim", dim, [int], self.name) self.init_prim_io_names(inputs=['x'], outputs=['y', 'indices']) + + +class Tril(Primitive): + """ + Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input, + the other elements of the result tensor out are set to 0. + The lower triangular part of the matrix is defined as the elements on and below the diagonal. + + Args: + diagonal (int): An optional attribute indicates the diagonal to consider, default to 0. + + Inputs: + - **x** (Tensor) - A Tensor with shape :math:`(x_1, x_2, ..., x_R)`. The rank must be at least 2. + Supporting all number types including bool. + + Outputs: + Tensor, the same shape and data type as the input. + + Raises: + TypeError: If `x` is not a Tensor. + TypeError: If `diagonal` is not an int. + TypeError: If the type of `x` is neither number nor bool. + ValueError: If the rank of `x` is less than 2. + + Supported Platforms: + ``CPU`` + + Examples: + >>> tril = ops.Tril() + >>> output = tril(Tensor(np.array([[-13.5383, 2.5474, ], [-5.7496, -3.4548]]), mindspore.float32)) + >>> print(output) + [[ -13.5383 0. ] + [ -5.7496 -3.4548]] + """ + + @prim_attr_register + def __init__(self, diagonal=0): + """Initialize Tril.""" + self.init_prim_io_names(inputs=["x"], outputs=["y"]) + validator.check_value_type("diagonal", diagonal, [int], self.name) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 335dc437fa3..96e1e5e8ec6 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -31,6 +31,7 @@ from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _quant_ops as Q from mindspore.ops.operations import nn_ops as nps +from mindspore.ops.operations.array_ops import Tril from mindspore.ops.operations.random_ops import NonDeterministicInts from mindspore.nn.layer import normalization from mindspore._c_expression import security @@ -2823,6 +2824,11 @@ test_case_array_ops = [ 'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])], 'skip': ['backward'], }), + ('Tril', { + 'block': Tril(), + 'desc_inputs': [Tensor(np.random.rand(3, 8, 9), mstype.float32)], + 'desc_brop': [Tensor(np.random.rand(5, 6, 6), mstype.float32)] + }), ] test_case_image_ops = [