From f23ffa4ae11a9914a750f6429052071f1b7dbe33 Mon Sep 17 00:00:00 2001
From: zhaoting
Date: Tue, 27 Oct 2020 11:32:30 +0800
Subject: [PATCH] add cpu AssignAdd int32 and int64

---
Reviewer notes (this region is ignored by `git am`):
 * AssignAdd stores input1[i] + input2[i] into out[i] and then writes the
   same value back into input1[i], so the first input is updated in place.
 * InitKernel now turns an empty (rank-0) output shape into the shape {1}
   before the input shapes are left-padded with 1s.
 * LaunchKernel no longer always uses 24 threads: for fewer than 128 * 24
   elements it uses ceil(lens / 128.0) threads, otherwise 24.
 * NOTE(review): lens == 0 makes thread_num == 0; the code that consumes
   thread_num lies outside these hunks -- confirm it tolerates zero.
 * The author e-mail address was lost when this patch was extracted and has
   not been reconstructed.

 .../cpu/arithmetic_cpu_kernel.cc             | 19 +++++++++++++++++--
 .../cpu/arithmetic_cpu_kernel.h              |  8 ++++++++
 .../backend/kernel_compiler/cpu/cpu_kernel.h |  2 +-
 3 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
index 2eac4619117..6792b815d9b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
@@ -20,6 +20,14 @@
 
 namespace mindspore {
 namespace kernel {
+template <typename T>
+void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = input1[i] + input2[i];
+    input1[i] = out[i];
+  }
+}
+
 template <typename T>
 void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) {
   for (size_t i = start; i < end; i++) {
@@ -65,11 +73,16 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = MUL;
   } else if (kernel_name == "Div") {
     operate_type_ = DIV;
+  } else if (kernel_name == prim::kPrimAssignAdd->name()) {
+    operate_type_ = ASSIGNADD;
   }
 
   input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   input_shape1_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
   output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  if (output_shape_.size() == 0) {
+    output_shape_.insert(output_shape_.begin(), 1);
+  }
   size_t l = input_shape0_.size();
   for (size_t i = 0; i < output_shape_.size() - l; ++i) {
     input_shape0_.insert(input_shape0_.begin(), 1);
@@ -138,8 +151,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
   T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
   T *output = reinterpret_cast<T *>(outputs[0]->addr);
   auto lens = outputs[0]->size / sizeof(T);
-  MS_LOG(INFO) << "lens=" << lens;
-  const size_t thread_num = 24;
+  size_t thread_num = lens < 128 * 24 ? std::ceil(lens / 128.0) : 24;
+  MS_LOG(INFO) << "lens=" << lens << "; use thread_num=" << thread_num;
   std::vector<std::thread> threads;
   threads.reserve(thread_num);
   size_t start = 0;
@@ -154,6 +167,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
       threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul<T>, this, input1, input2, output, start, end));
     } else if (operate_type_ == DIV) {
       threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div<T>, this, input1, input2, output, start, end));
+    } else if (operate_type_ == ASSIGNADD) {
+      threads.emplace_back(std::thread(&ArithmeticCPUKernel::AssignAdd<T>, this, input1, input2, output, start, end));
     }
     start += once_compute_size;
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
index 07d984528b5..3de579b5502 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
@@ -45,6 +45,8 @@ class ArithmeticCPUKernel : public CPUKernel {
   void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end);
   template <typename T>
   void Div(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end);
   std::vector<size_t> input_shape0_;
   std::vector<size_t> input_shape1_;
   std::vector<size_t> input_element_num0_;
@@ -64,6 +66,12 @@ MS_REG_CPU_KERNEL(
 MS_REG_CPU_KERNEL(
   Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   ArithmeticCPUKernel);
+MS_REG_CPU_KERNEL(
+  AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  ArithmeticCPUKernel);
+MS_REG_CPU_KERNEL(
+  AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  ArithmeticCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index f318eaa2cd3..da746f39d17 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -52,7 +52,7 @@ const char END[] = "end";
 const char SIZE[] = "size";
 const char USE_NESTEROV[] = "use_nesterov";
 const char GROUP[] = "group";
-enum OperateType { ADD = 0, SUB, MUL, DIV, SQUARE, SQRT };
+enum OperateType { ADD = 0, SUB, MUL, DIV, SQUARE, SQRT, ASSIGNADD };
 
 class CPUKernel : public kernel::KernelMod {
  public: