diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cummax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cummax_cpu_kernel.cc new file mode 100644 index 00000000000..6cdf24a15fc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cummax_cpu_kernel.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/cummax_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +template +void CummaxCPUKernelMod::InitKernel(const CNodePtr &kernel_node) { + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + output1_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + output2_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 1); + dim_ = AnfAlgo::GetNodeAttr(kernel_node, "dim"); +} + +template +bool CummaxCPUKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + auto input_data_addr = reinterpret_cast(inputs[0]->addr); + auto output1_data_addr = reinterpret_cast(outputs[0]->addr); + auto output2_data_addr = reinterpret_cast(outputs[1]->addr); + + const size_t dims = input_shape_.size(); + if (dims == 0) { + MS_LOG(EXCEPTION) << "The value of `dims` can not be 0"; + } + dim_ = (dim_ + dims) % dims; + std::vector p{1}; + + for (int64_t i = (int64_t)input_shape_.size() - 1; i >= 0; i--) + p.push_back(p[(int64_t)input_shape_.size() - 1 - i] * input_shape_[i]); + reverse(p.begin(), p.end()); + + size_t input_stride = p[dim_ + 1]; + size_t output1_stride = p[dim_ + 1]; + size_t output2_stride = p[dim_ + 1]; + size_t input_dim_size = input_shape_[dim_]; + + int exit_ok = 0; + std::vector counter(dims, 0); + + while (!exit_ok) { + T out = input_data_addr[0]; + int idx = 0; + for (size_t i = 0; i < input_dim_size; i++) { + T cur = input_data_addr[i * input_stride]; + if (cur >= out) { + out = cur; + idx = i; + } + output1_data_addr[i * output1_stride] = out; + output2_data_addr[i * output2_stride] = idx; + } + + if (dims == 1) break; + for (size_t dim_i = 0; dim_i < dims; dim_i++) { + if (dim_i == dim_) { + if (dim_i == dims - 1) { + exit_ok = 1; + break; + } + continue; + } + counter[dim_i]++; + input_data_addr += p[dim_i + 1]; + output1_data_addr += p[dim_i + 1]; + output2_data_addr += p[dim_i + 1]; + + if (counter[dim_i] == input_shape_[dim_i]) { + if (dim_i == dims - 1) { + exit_ok = 1; + break; + } else { + input_data_addr -= counter[dim_i] * p[dim_i + 1]; + output1_data_addr -= counter[dim_i] * p[dim_i + 1]; + output2_data_addr -= counter[dim_i] * p[dim_i + 1]; + counter[dim_i] = 0; + } + } else { + break; + } + } + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cummax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cummax_cpu_kernel.h new file mode 100644 index 00000000000..e277f84335e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cummax_cpu_kernel.h @@ -0,0 +1,69 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class CummaxCPUKernelMod : public NativeCpuKernelMod { + public: + CummaxCPUKernelMod() = default; + ~CummaxCPUKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) override; + + private: + std::vector input_shape_; + std::vector output1_shape_; + std::vector output2_shape_; + size_t dim_; +}; + +MS_REG_CPU_KERNEL_T( + Cummax, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), + CummaxCPUKernelMod, float); +MS_REG_CPU_KERNEL_T( + Cummax, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), + CummaxCPUKernelMod, float16); +MS_REG_CPU_KERNEL_T( + Cummax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + CummaxCPUKernelMod, int32_t); +MS_REG_CPU_KERNEL_T( + Cummax, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + CummaxCPUKernelMod, int64_t); +MS_REG_CPU_KERNEL_T( + Cummax, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), + CummaxCPUKernelMod, int8_t); +MS_REG_CPU_KERNEL_T( + Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), + CummaxCPUKernelMod, uint8_t); +MS_REG_CPU_KERNEL_T( + Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), + CummaxCPUKernelMod, uint32_t); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index db3dd25d248..f3caed483e2 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -119,6 +119,7 @@ constexpr auto kReshape = "Reshape"; constexpr auto kLstsq = "Lstsq"; constexpr auto kLowerBound = "LowerBound"; constexpr auto kUpperBound = "UpperBound"; +constexpr auto kCummax = "Cummax"; // NN constexpr auto kCTCLoss = "CTCLoss"; @@ -350,6 +351,7 @@ MS_CORE_API inline const PrimitivePtr kPrimExtractVolumePatches = std::make_shar MS_CORE_API inline const PrimitivePtr kPrimLstsq = std::make_shared(kLstsq); MS_CORE_API inline const PrimitivePtr kPrimLowerBound = std::make_shared(kLowerBound); MS_CORE_API inline const PrimitivePtr kPrimUpperBound = std::make_shared(kUpperBound); +MS_CORE_API inline const PrimitivePtr kPrimCummax = std::make_shared(kCummax); // NN MS_CORE_API inline const PrimitivePtr kPrimCeLU = std::make_shared("CeLU"); diff --git a/mindspore/core/ops/cummax.cc b/mindspore/core/ops/cummax.cc new file mode 100644 index 00000000000..ddb3110cf77 --- /dev/null +++ b/mindspore/core/ops/cummax.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/cummax.h" +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + auto x_shape = input_args[0]->BuildShape(); + auto x_shape_value = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape)[kShape]; + auto dim = GetValue(primitive->GetAttr("dim")); + if (x_shape_value.size() <= 0) { + MS_EXCEPTION(ValueError) << "Inputs should not be a " << x_shape_value.size() << " dimensional tensor."; + } + if (dim >= static_cast(x_shape_value.size()) || dim < -static_cast(x_shape_value.size())) { + MS_EXCEPTION(ValueError) << "The value of `dim` should be in the range of [" + << -static_cast(x_shape_value.size()) << "," + << static_cast(x_shape_value.size()) << ")"; + } + return std::make_shared(std::vector{x_shape, x_shape}); +} + +TuplePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) { + auto op_name = primitive->name(); + const std::set valid_types = {kInt8, kInt32, kInt64, kUInt8, kUInt32, kFloat16, kFloat32}; + auto y_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, op_name); + auto indices_type = kInt64; + return std::make_shared(std::vector{y_type, indices_type}); +} + +AbstractBasePtr CummaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const int64_t input_num = 1; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); + auto types = InferType(primitive, input_args); + auto shapes = InferShape(primitive, input_args); + return abstract::MakeAbstract(shapes, types); +} + +REGISTER_PRIMITIVE_EVAL_IMPL(Cummax, prim::kPrimCummax, CummaxInfer, nullptr, true); +} // namespace +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/cummax.h b/mindspore/core/ops/cummax.h new file mode 100644 index 00000000000..1e59f6d68b0 --- /dev/null +++ b/mindspore/core/ops/cummax.h @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CUMMAX_H_ +#define MINDSPORE_CORE_OPS_CUMMAX_H_ +#include +#include + +#include "ops/primitive_c.h" +#include "ops/op_utils.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCummax = "Cummax"; +class Cummax : public PrimitiveC { + public: + Cummax() : PrimitiveC(kNameCummax) { InitIOName({"x"}, {"y", "indices"}); } + ~Cummax() = default; + MS_DECLARE_PARENT(Cummax, PrimitiveC); +}; +AbstractBasePtr CummaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimCummaxPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_CUMMAX_H_ diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index 488e3374093..4207eb038ce 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -103,3 +103,4 @@ from .upper_bound import _upper_bound_aicpu from .grid_sampler_3d import _grid_sampler_3d_aicpu from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu from .cross import _cross_aicpu +from .cummax import _cummax_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/cummax.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/cummax.py new file mode 100644 index 00000000000..4f6fb6ca155 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/cummax.py @@ -0,0 +1,37 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Cummax op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +cummax_op_info = AiCPURegOp("Cummax") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .output(1, "indices", "required") \ + .attr("dim", "int") \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I64_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I64_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.I64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I64_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I64_Default) \ + .get_op_info() + + +@op_info_register(cummax_op_info) +def _cummax_aicpu(): + """Cummax AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index 10d2f9486a6..c7721ea5862 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -35,9 +35,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted, TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches, - LowerBound, UpperBound) -from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, - Broadcast, + LowerBound, UpperBound, Cummax) +from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast, _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad, _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather) @@ -517,6 +516,7 @@ __all__ = [ "Custom", "LuSolve", "CholeskyInverse", + "Cummax", ] __sponge__ = [ diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index b10300994f7..e5dd5d5f24b 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -7017,3 +7017,63 @@ class UpperBound(Primitive): valid_values = (mstype.int32, mstype.int64) validator.check_type_name("out_type", out_type, valid_values, self.name) self.init_prim_io_names(inputs=['sorted_x', 'values'], outputs=['y']) + + +class Cummax(Primitive): + """ + Computes the cumulative max and indice of input tensor along dim.Returns a tuple (values,indices) where 'values' + is the cumulative maximum value of input elements in the dimension 'dim'and 'indices' is the index position for + each maximum value. + + .. warning:: + This is an experimental prototype that is subject to change and/or deletion. + + .. math:: + + y_i = max(x_1 , x_2 , x_3 ,... ,x_i) + + Args: + dim (int): The dim to accumulate the tensor's value. Must be in the range [-rank(input), rank(input)). + The default value is -1. + + Inputs: + - **input** (Tensor) - The input tensor whose dtype is int8, int32, int64, uint8, uint32, float16, float32. + + Outputs: + - **values** (Tensor), the shape of the output tensor is consistent with the input tensor's. + - **indices** (Tensor), the shape of the output tensor is consistent with the input tensor's. + + Raises: + TypeError: If `input` is not a Tensor. + TypeError: If `dim` is not an int. + ValueError: If `dim` is out of range, `dim` should be [-len(input.shape), len(input.shape)-1]. + + Supported Platforms: + ``CPU`` + + Examples: + >>> import mindspore + >>> import numpy as np + >>> from mindspore import Tensor + >>> import mindspore.ops as ops + >>> cummax = ops.Cummax(dim=0) + >>> x = Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float32)) + >>> output = cummax(x) + >>> print(output) + values: + [[ 3. 4. 6. 10.] + [ 3. 6. 7. 10.] + [ 4. 6. 8. 10.] + [ 4. 6. 8. 10.]] + indices: + [[0 0 0 0] + [0 1 1 0] + [2 1 2 0] + [2 1 2 0]] + """ + + @prim_attr_register + def __init__(self, dim=-1): + """Initialize Cummax""" + validator.check_value_type("dim", dim, [int], self.name) + self.init_prim_io_names(inputs=['x'], outputs=['y', 'indices']) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index ef0ae07c0b4..acd83e9d0ef 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -2803,6 +2803,11 @@ test_case_array_ops = [ Tensor([[3], [6], [7], [8]], mstype.int8)], 'skip': ['backward'], }), + ('Cummax', { + 'block': P.Cummax(dim=-1), + 'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])], + 'skip': ['backward'], + }), ] test_case_image_ops = [