diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.cc new file mode 100644 index 00000000000..9aa7c87b323 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/bucketize_cpu_kernel.h" +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kOutputNum = 1; +const size_t kInputNum = 1; +const size_t kParallelDataNumSameShape = 64 * 1024; +const size_t kParallelDataNumSameShapeMid = 35 * 1024; +} // namespace + +void BucketizeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); + input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0); + boundaries_ = common::AnfAlgo::GetNodeAttr>(kernel_node, "boundaries"); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); +} + +bool BucketizeCpuKernelMod::Launch(const std::vector &inputs, + const std::vector & /* workspace */, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_); + if (dtype_ != kNumberTypeInt32 && dtype_ != kNumberTypeInt64 && dtype_ != kNumberTypeFloat32 && + dtype_ != kNumberTypeFloat64) { + MS_LOG(EXCEPTION) << "Input data type must int32 or int64 or float32 or float64, but got data type." << dtype_; + return false; + } + size_t input_sizes = input_shape_.size(); + size_t output_sizes = output_shape_.size(); + if (input_sizes != output_sizes) { + MS_LOG(EXCEPTION) << "The tensor shape of input need be same with output."; + return false; + } + // BucketizeCompute(inputs, outputs); + switch (dtype_) { + case kNumberTypeInt32: + return BucketizeCompute(inputs, outputs); + case kNumberTypeInt64: + return BucketizeCompute(inputs, outputs); + case kNumberTypeFloat32: + return BucketizeCompute(inputs, outputs); + case kNumberTypeFloat64: + return BucketizeCompute(inputs, outputs); + default: + MS_LOG(ERROR) << "Unsupported data type."; + } + return true; +} + +template +bool BucketizeCpuKernelMod::BucketizeCompute(const std::vector &inputs, + const std::vector &outputs) { + auto input_data = reinterpret_cast(inputs[0]->addr); + auto output_data = reinterpret_cast(outputs[0]->addr); + size_t data_num_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies()); + std::vector boundaries_data = boundaries_; + std::sort(boundaries_data.begin(), boundaries_data.end()); + if (data_num_ >= kParallelDataNumSameShape) { + auto sharder_bucketize = [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]); + output_data[i] = first_bigger_it - boundaries_data.begin(); + } + }; + ParallelLaunchAutoSearch(sharder_bucketize, data_num_, this, ¶llel_search_info_); + } else { + for (size_t i = 0; i < data_num_; i++) { + auto first_bigger_it = std::upper_bound(boundaries_data.begin(), boundaries_data.end(), input_data[i]); + output_data[i] = first_bigger_it - boundaries_data.begin(); + } + } + return true; +} + +std::vector BucketizeCpuKernelMod::GetOpSupport() { + static std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32)}; + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Bucketize, BucketizeCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.h new file mode 100644 index 00000000000..9ab2ecb4999 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/bucketize_cpu_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BUCKETIZE_CPU_KERNEL_H_ + +#include + +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class BucketizeCpuKernelMod : public NativeCpuKernelMod { + public: + BucketizeCpuKernelMod() = default; + ~BucketizeCpuKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + template + bool BucketizeCompute(const std::vector &inputs, const std::vector &outputs); + + protected: + std::vector GetOpSupport() override; + + private: + std::vector input_shape_; + std::vector output_shape_; + std::vector boundaries_; + TypeId dtype_{kTypeUnknown}; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index c43c85dc43f..2b8b273078e 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -743,6 +743,7 @@ GVAR_DEF(PrimitivePtr, kPrimInv, std::make_shared("Inv")); GVAR_DEF(PrimitivePtr, kPrimBitwiseOr, std::make_shared("BitwiseOr")); GVAR_DEF(PrimitivePtr, kPrimBitwiseAnd, std::make_shared("BitwiseAnd")); GVAR_DEF(PrimitivePtr, kPrimBitwiseXor, std::make_shared("BitwiseXor")); +GVAR_DEF(PrimitivePtr, kPrimBucketize, std::make_shared("Bucketize")); GVAR_DEF(PrimitivePtr, kPrimEinsum, std::make_shared("Einsum")); GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared("EinsumGrad")); diff --git a/mindspore/core/ops/bucketize.cc b/mindspore/core/ops/bucketize.cc new file mode 100644 index 00000000000..6325e5d98ff --- /dev/null +++ b/mindspore/core/ops/bucketize.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/bucketize.h" + +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr BucketizeInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto x = input_args[0]->BuildShape(); + MS_EXCEPTION_IF_NULL(x); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto out_shape = x_shape; + return std::make_shared(out_shape); +} + +TypePtr BucketizeInferType(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + MS_EXCEPTION_IF_NULL(input_args[0]); + auto x_type = input_args[0]->BuildType(); + (void)CheckAndConvertUtils::CheckTensorTypeValid("input", x_type, common_valid_types, prim_name); + return std::make_shared(kInt32); +} +} // namespace + +MIND_API_BASE_IMPL(Bucketize, PrimitiveC, BaseOperator); +AbstractBasePtr BucketizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const size_t input_num = 1; + (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name()); + auto infer_type = BucketizeInferType(primitive, input_args); + auto infer_shape = BucketizeInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} +REGISTER_PRIMITIVE_EVAL_IMPL(Bucketize, prim::kPrimBucketize, BucketizeInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/bucketize.h b/mindspore/core/ops/bucketize.h new file mode 100644 index 00000000000..52e67a0797e --- /dev/null +++ b/mindspore/core/ops/bucketize.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_BUCKETIZE_H_ +#define MINDSPORE_CORE_OPS_BUCKETIZE_H_ +#include +#include +#include +#include +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameBucketize = "Bucketize"; +/// \brief Bucketizes 'input' based on 'boundaries'. +/// Refer to Python API @ref mindspore.ops.Bucketize for more details. +class MIND_API Bucketize : public BaseOperator { + public: + /// \brief Constructor. + Bucketize() : BaseOperator(kNameBucketize) { InitIOName({"input"}, {"output"}); } + // /// \brief Destructor. + // ~Bucketize() = default; + MIND_API_BASE_MEMBER(Bucketize); +}; + +AbstractBasePtr BucketizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore +#endif // MINDSPORE_CORE_OPS_BUCKETIZE_H_ diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index a95bd28fd4e..1195a994eb3 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -139,4 +139,5 @@ from .priority_replay_buffer import _prb_sample_op_cpu from .priority_replay_buffer import _prb_update_op_cpu from .right_shift import _right_shift_aicpu from .tril import _tril_aicpu +from .bucketize import _bucketize_aicpu from .triu import _triu_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/bucketize.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/bucketize.py new file mode 100644 index 00000000000..755a4a283a7 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/bucketize.py @@ -0,0 +1,34 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Bucketize op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +bucketize_op_info = AiCPURegOp("Bucketize") \ + .fusion_type("OPAQUE") \ + .attr("boundaries", "listFloat") \ + .input(0, "input", "required") \ + .output(0, "output", "required") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(bucketize_op_info) +def _bucketize_aicpu(): + """Bucketize aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py index f042c09496b..0f5e0e4c606 100644 --- a/mindspore/python/mindspore/ops/operations/math_ops.py +++ b/mindspore/python/mindspore/ops/operations/math_ops.py @@ -1054,6 +1054,51 @@ class ReduceMin(_Reduce): """ +class Bucketize(Primitive): + """ + Bucketizes 'input' based on 'boundaries'. + + Args: + boundaries (list_float): A sorted list of floats gives the boundary of the buckets, and no default value. + + Inputs: + - **input** (Tensor) - A tensor containing the search value(s). + + Outputs: + Tensor, with the same shape as the input, and data type is int32. + + Raises: + TypeError: If `boundaries` is not a listFloat. + TypeError: If `input` is not a Tensor. + + Supported Platforms: + ``CPU`` + + Examples: + >>> class Bucketize(nn.Cell): + ... def __init__(self, boundaries): + ... super().__init__() + ... self.bucketize = op.Bucketize(boundaries=boundaries) + ... def construct(self, input): + ... return self.bucketize(input) + >>> input = Tensor(np.array([[3, 6, 9], [3, 6, 9]]).astype(np.int32)) + >>> boundaries = list(np.array([1., 3., 5., 7., 9.])) + >>> net = Bucketize(boundaries) + >>> output = net(input) + >>> print(output) + [[2 3 5] + [2 3 5]] + """ + + @prim_attr_register + def __init__(self, boundaries): + """Initialize Bucketize""" + validator.check_value_type("boundaries", boundaries, [list], self.name) + for index, one_boundaries in enumerate(boundaries): + validator.check_value_type('boundaries[%d]' % index, one_boundaries, [float], self.name) + self.init_prim_io_names(inputs=['input'], outputs=['output']) + + class ReduceProd(_Reduce): """ Reduces a dimension of a tensor by multiplying all elements in the dimension, by default. And also can diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index f4da9dcf029..5410cd953dc 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -30,6 +30,7 @@ from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _quant_ops as Q +from mindspore.ops.operations.math_ops import Bucketize from mindspore.ops.operations import nn_ops as nps from mindspore.ops.operations.array_ops import Tril from mindspore.ops.operations.random_ops import NonDeterministicInts @@ -1879,6 +1880,10 @@ test_case_math_ops = [ 'block': P.Real(), 'desc_inputs': [[2, 2]], 'skip': ['backward']}), + ('Bucketize', { + 'block': Bucketize(boundaries=[1., 3., 5., 7., 9.]), + 'desc_inputs': [Tensor(np.array([[-1, 6, 8], [3, 6, 9]]).astype(np.float))], + 'skip': ['backward']}), ] test_case_nn_ops = [