From 2cf850b45ac09cfdfa9bd858ce2a19b1189b9f51 Mon Sep 17 00:00:00 2001 From: shenwei41 Date: Wed, 8 Mar 2023 03:19:51 +0800 Subject: [PATCH] Add Size cpu operation --- .../device/cpu/kernel/size_cpu_kernel.cc | 101 ++++++++++++++++++ .../device/cpu/kernel/size_cpu_kernel.h | 51 +++++++++ mindspore/core/ops/size.cc | 61 ++++++++++- .../mindspore/ops/operations/array_ops.py | 16 +-- tests/st/ops/cpu/test_size_op.py | 87 +++++++++++++++ 5 files changed, 298 insertions(+), 18 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/size_cpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/size_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_size_op.py diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/size_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/size_cpu_kernel.cc new file mode 100644 index 00000000000..23359a3b788 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/size_cpu_kernel.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/cpu/kernel/size_cpu_kernel.h" +#include +#include +#include +#include +#include +#include +#include "include/common/thread_pool.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kSizeInputsNum = 1; +const size_t kSizeOutputsNum = 1; +}; // namespace +bool SizeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + kernel_name_ = base_operator->name(); + auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); + auto is_match = MatchKernelAttr(tensor_attr, GetOpSupport()).first; + if (!is_match) { + MS_LOG_ERROR << "Can not match kernel based on given attr!"; + return false; + } + + if (Resize(base_operator, inputs, outputs) == KRET_RESIZE_FAILED) { + MS_LOG_ERROR << "Resize failed!"; + return false; + } + return true; +} + +int SizeCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost) { + MS_EXCEPTION_IF_NULL(base_operator); + if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { + return ret; + } + auto shape_vector = inputs[kIndex0]->GetShapeVector(); + int64_t elements = 1; + for (size_t i = 0; i < shape_vector.size(); i++) { + elements *= shape_vector[i]; + } + input_elements = elements; + return KRET_OK; +} + +bool SizeCpuKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSizeInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSizeOutputsNum, kernel_name_); + auto output_data = reinterpret_cast(outputs[kIndex0]->addr); + MS_EXCEPTION_IF_NULL(output_data); + output_data[kIndex0] = input_elements; + return true; +} + +std::vector SizeCpuKernelMod::GetOpSupport() { + return { + KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeComplex).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt4).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeGLUInt).AddOutputAttr(kNumberTypeInt32), + }; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Size, SizeCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/size_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/size_cpu_kernel.h new file mode 100644 index 00000000000..d9682b375e3 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/size_cpu_kernel.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/cpu/kernel/nnacl/op_base.h" + +namespace mindspore { +namespace kernel { +class SizeCpuKernelMod : public NativeCpuKernelMod { + public: + SizeCpuKernelMod() = default; + ~SizeCpuKernelMod() override = default; + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &others = std::map()) override; + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + int32_t input_elements; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIZE_CPU_KERNEL_H_ diff --git a/mindspore/core/ops/size.cc b/mindspore/core/ops/size.cc index 7836b91e6af..45a2bc7565e 100644 --- a/mindspore/core/ops/size.cc +++ b/mindspore/core/ops/size.cc @@ -15,13 +15,68 @@ */ #include "ops/size.h" -#include "ops/primitive_c.h" -#include "utils/log_adapter.h" +#include "utils/check_convert_utils.h" +#include "ops/op_utils.h" #include "mindapi/src/helper.h" namespace mindspore { namespace ops { +namespace { +constexpr int64_t input_num = 1; +} // namespace +class SizeInfer : public abstract::OpInferBase { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const override { + MS_EXCEPTION_IF_NULL(primitive); + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); + return abstract::kNoShape; + } + + TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override { + MS_EXCEPTION_IF_NULL(primitive); + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); + TypePtr res = kInt64; + return res; + } + + ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector &input_args) const { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); + auto input_type = input_args[0]->BuildType(); + MS_EXCEPTION_IF_NULL(input_type); + if (!input_type->isa()) { + MS_EXCEPTION(TypeError) << "For '" << prim_name + << "', input must be a Tensor, but got: " << input_type->ToString() << "."; + } + auto input_shape_ptr = input_args[0]->BuildShape(); + MS_EXCEPTION_IF_NULL(input_shape_ptr); + auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_shape_ptr); + auto input_shape = shape_map[kShape]; + if (IsDynamicRank(input_shape) || IsDynamicShape(input_shape)) { + return kAnyValue; + } + size_t elements = 1; + for (size_t i = 0; i < input_shape.size(); i++) { + elements *= input_shape[i]; + } + auto elements_value = SizeToLong(elements); + ValuePtr res = MakeValue(elements_value); + return res; + } + + AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive, + const std::vector &input_args) const override { + auto type = InferType(primitive, input_args); + auto shape = InferShape(primitive, input_args); + auto value = InferValue(primitive, input_args); + auto res = MakeAbstract(shape, type); + res->set_value(value); + return res; + } +}; MIND_API_OPERATOR_IMPL(Size, BaseOperator); -REGISTER_PRIMITIVE_C(kNameSize, Size); +REGISTER_PRIMITIVE_OP_INFER_IMPL(Size, prim::kPrimSize, SizeInfer, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 942b66962f5..ced063fa47c 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -15,7 +15,6 @@ """Operators for array.""" import copy -import functools import itertools import numbers @@ -1204,7 +1203,7 @@ class Rank(PrimitiveWithInfer): return len(x.shape) -class Size(PrimitiveWithInfer): +class Size(Primitive): r""" Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the Tensor. @@ -1226,19 +1225,6 @@ class Size(PrimitiveWithInfer): def __init__(self): """Initialize Size""" - def __infer__(self, x): - size = 1 - validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) - shp = x['shape'] - if not shp: - size = 0 - else: - size = functools.reduce(lambda x, y: x * y, x['shape']) - out = {'shape': None, - 'dtype': mstype.int64, - 'value': size} - return out - class MatrixDiagV3(Primitive): """ diff --git a/tests/st/ops/cpu/test_size_op.py b/tests/st/ops/cpu/test_size_op.py new file mode 100644 index 00000000000..5f275fb383f --- /dev/null +++ b/tests/st/ops/cpu/test_size_op.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class Net(nn.Cell): + + def __init__(self): + super(Net, self).__init__() + self.ops = P.Size() + + def construct(self, x): + return self.ops(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_size_1_dimension(mode): + """ + Feature: test pynative mode and graph mode + Description: Test 1-D Tensor + Expectation: the result match to expected value + """ + np_array = np.array([2, 3, 4]).astype(np.int32) + input_x = Tensor(np_array) + expect = 3 + net = Net() + out = net(input_x) + assert out == expect + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_size_2_dimension(mode): + """ + Feature: test pynative mode and graph mode + Description: Test 2-D Tensor + Expectation: the result match to expected value + """ + np_array = np.array([[2, 2], [2, 2], [3, 3]]).astype(np.int32) + input_x = Tensor(np_array) + expect = 6 + net = Net() + out = net(input_x) + assert out == expect + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_size_3_dimension(mode): + """ + Feature: test pynative mode and graph mode + Description: Test 3-D Tensor + Expectation: the result match to expected value + """ + np_array = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]], [[5, 5], [6, 6]]]).astype(np.int32) + input_x = Tensor(np_array) + expect = 12 + net = Net() + out = net(input_x) + assert out == expect