From b5ec25cc49f090775c359e6a3d6f77a54bdc15d5 Mon Sep 17 00:00:00 2001
From: buxue
Date: Tue, 18 May 2021 17:51:44 +0800
Subject: [PATCH] develop int32 AddN op for cpu

---
 .../cpu/mkldnn/addn_cpu_kernel.cc             | 43 +++++++++---
 .../cpu/mkldnn/addn_cpu_kernel.h              |  8 ++-
 tests/st/ops/cpu/test_addn_op.py              | 66 ++++++++++---------
 3 files changed, 75 insertions(+), 42 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc
index f8c90c47f3a..4e7a745a022 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc
@@ -17,14 +17,25 @@
 #include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h"
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
+#include "backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h"
+#include "backend/kernel_compiler/cpu/nnacl/errorcode.h"
 #include "utils/ms_utils.h"
+#include "common/thread_pool.h"
 
 namespace mindspore {
 namespace kernel {
+void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) {
+  int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start);
+  if (ret != NNACL_OK) {
+    MS_LOG(EXCEPTION) << "Add failed.";
+  }
+}
+
 void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
   CheckParam(kernel_node);
+  input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
   std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
@@ -42,15 +53,31 @@ void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                            const std::vector<kernel::AddressPtr> & /*workspace*/,
                            const std::vector<kernel::AddressPtr> &outputs) {
-  SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
-  SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
-  SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
-  ExecutePrimitive();
-  for (size_t index = 2; index < input_num_; ++index) {
-    SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
-    SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
-    SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
-    ExecutePrimitive();
+  if (dtype_ == kNumberTypeFloat32) {
+    SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
+    SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
+    SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
+    ExecutePrimitive();
+    for (size_t index = 2; index < input_num_; ++index) {
+      SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
+      SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
+      SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
+      ExecutePrimitive();
+    }
+  } else if (dtype_ == kNumberTypeInt32) {
+    size_t elements_num = outputs[0]->size / sizeof(int);
+    const auto input_0 = reinterpret_cast<int *>(inputs[0]->addr);
+    const auto input_1 = reinterpret_cast<int *>(inputs[1]->addr);
+    auto output = reinterpret_cast<int *>(outputs[0]->addr);
+    auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
+    CPUKernelUtils::ParallelFor(task_0, elements_num);
+    for (size_t index = 2; index < input_num_; ++index) {
+      const auto input = reinterpret_cast<int *>(inputs[index]->addr);
+      auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2);
+      CPUKernelUtils::ParallelFor(task, elements_num);
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "AddN only support float32 and int32, but got " << TypeIdToType(dtype_)->ToString();
   }
   return true;
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h
index ce4218ca598..21547d6082f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h
@@ -24,7 +24,7 @@ namespace mindspore {
 namespace kernel {
 class AddNCPUKernel : public MKLCPUKernel {
  public:
-  AddNCPUKernel() : input_num_(0) {}
+  AddNCPUKernel() = default;
   ~AddNCPUKernel() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
@@ -34,13 +34,17 @@ class AddNCPUKernel : public MKLCPUKernel {
 
  private:
   void CheckParam(const CNodePtr &kernel_node);
-  size_t input_num_;
+  size_t input_num_{0};
   std::vector<size_t> output_shape_;
+  TypeId dtype_{kNumberTypeFloat32};
 };
 
 MS_REG_CPU_KERNEL(AddN,
                   KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   AddNCPUKernel);
+MS_REG_CPU_KERNEL(AddN,
+                  KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                  AddNCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 
diff --git a/tests/st/ops/cpu/test_addn_op.py b/tests/st/ops/cpu/test_addn_op.py
index d8cffe09842..bd2a974f147 100644
--- a/tests/st/ops/cpu/test_addn_op.py
+++ b/tests/st/ops/cpu/test_addn_op.py
@@ -19,60 +19,62 @@ import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
-from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 
-class Net2I(nn.Cell):
+
+class Net2Inputs(nn.Cell):
     def __init__(self):
-        super(Net2I, self).__init__()
+        super(Net2Inputs, self).__init__()
         self.addn = P.AddN()
 
     def construct(self, x, y):
         return self.addn((x, y))
 
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_net_2Input():
-    x = np.arange(2 * 3 * 2).reshape(2, 3, 2).astype(np.float32)
-    y = np.arange(2 * 3 * 2).reshape(2, 3, 2).astype(np.float32)
-    addn = Net2I()
-    output = addn(Tensor(x, mstype.float32), Tensor(y, mstype.float32))
-    print("output:\n", output)
-    expect_result = [[[0., 2.],
-                      [4., 6.],
-                      [8., 10.]],
-                     [[12., 14.],
-                      [16., 18.],
-                      [20., 22.]]]
+def test_two_tensors_add():
+    x = np.arange(2 * 3 * 2).reshape((2, 3, 2))
+    y = np.arange(88, 2 * 3 * 2 + 88).reshape((2, 3, 2))
+    addn_net = Net2Inputs()
+    dtypes = (np.int32, np.float32)
+    for dtype in dtypes:
+        output = addn_net(Tensor(x.astype(dtype)), Tensor(y.astype(dtype)))
+        expect_result = (x + y).astype(dtype)
+        assert output.asnumpy().dtype == expect_result.dtype
+        assert np.array_equal(output.asnumpy(), expect_result)
 
-    assert (output.asnumpy() == expect_result).all()
 
-class Net3I(nn.Cell):
+class Net4Inputs(nn.Cell):
     def __init__(self):
-        super(Net3I, self).__init__()
+        super(Net4Inputs, self).__init__()
         self.addn = P.AddN()
 
-    def construct(self, x, y, z):
-        return self.addn((x, y, z))
+    def construct(self, x, y, m, n):
+        return self.addn((x, y, m, n))
 
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_net_3Input():
-    x = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
-    y = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
-    z = np.arange(2 * 3).reshape(2, 3).astype(np.float32)
-    addn = Net3I()
-    output = addn(Tensor(x, mstype.float32), Tensor(y, mstype.float32), Tensor(z, mstype.float32))
-    print("output:\n", output)
-    expect_result = [[0., 3., 6.],
-                     [9., 12., 15]]
+def test_four_tensors_add():
+    x = np.arange(2 * 3).reshape((2, 3))
+    y = np.arange(1, 2 * 3 + 1).reshape((2, 3))
+    m = np.arange(2, 2 * 3 + 2).reshape((2, 3))
+    n = np.arange(3, 2 * 3 + 3).reshape((2, 3))
+    addn_net = Net4Inputs()
+    dtypes = (np.int32, np.float32)
+    for dtype in dtypes:
+        output = addn_net(Tensor(x.astype(dtype)), Tensor(y.astype(dtype)),
+                          Tensor(m.astype(dtype)), Tensor(n.astype(dtype)))
+        expect_result = (x + y + m + n).astype(dtype)
+        assert output.asnumpy().dtype == expect_result.dtype
+        assert np.array_equal(output.asnumpy(), expect_result)
 
-    assert (output.asnumpy() == expect_result).all()
 
 if __name__ == '__main__':
-    test_net_2Input()
-    test_net_3Input()
+    test_two_tensors_add()
+    test_four_tensors_add()