From 8a8ac86b3643bb5b9432d6647ad7ddd0ccf5c565 Mon Sep 17 00:00:00 2001 From: zhuyuxiao Date: Thu, 8 Apr 2021 16:40:21 +0800 Subject: [PATCH] tensor add update: support int, uint operation on CPU kernel --- .../cpu/tensoradd_cpu_kernel.cc | 19 ++++++++++--------- .../cpu/tensoradd_cpu_kernel.h | 11 +++++++++-- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/tensoradd_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/tensoradd_cpu_kernel.cc index 7bc2b17a2b1..ae78bcb4198 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/tensoradd_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/tensoradd_cpu_kernel.cc @@ -19,8 +19,8 @@ namespace mindspore { namespace kernel { - -void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { +template +void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); // Init shape ans strides input_shape_a_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); @@ -28,13 +28,14 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); } -bool TensorAddCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr_a = reinterpret_cast(inputs[0]->addr); - auto input_addr_b = reinterpret_cast(inputs[1]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto output_size = outputs[0]->size / sizeof(float); +template +bool TensorAddCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + T *input_addr_a = reinterpret_cast(inputs[0]->addr); + T *input_addr_b = reinterpret_cast(inputs[1]->addr); + T *output_addr = reinterpret_cast(outputs[0]->addr); + size_t output_size = outputs[0]->size / sizeof(T); if (input_shape_a_ == input_shape_b_) { auto task = [output_addr, input_addr_a, input_addr_b](size_t start, size_t end) { for (size_t i = start; i < end; ++i) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/tensoradd_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/tensoradd_cpu_kernel.h index cefd2d87571..2e6845baac1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/tensoradd_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/tensoradd_cpu_kernel.h @@ -23,6 +23,7 @@ namespace mindspore { namespace kernel { +template class TensorAddCPUKernel : public CPUKernel { public: TensorAddCPUKernel() = default; @@ -39,9 +40,15 @@ class TensorAddCPUKernel : public CPUKernel { std::vector output_shape_; }; -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( Add, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TensorAddCPUKernel); + TensorAddCPUKernel, float); +MS_REG_CPU_KERNEL_T( + Add, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + TensorAddCPUKernel, int); +MS_REG_CPU_KERNEL_T( + Add, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + TensorAddCPUKernel, uint32_t); } // namespace kernel } // namespace mindspore