!7809 add cpu AssignAdd int32 and int64

Merge pull request !7809 from zhaoting/assingnadd
This commit is contained in:
mindspore-ci-bot 2020-10-27 23:03:49 +08:00 committed by Gitee
commit 8265f8deae
3 changed files with 26 additions and 3 deletions

View File

@@ -20,6 +20,14 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T>
void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end) {
  // Element-wise AssignAdd over the slice [start, end): accumulates input2 into
  // input1 in place (AssignAdd's variable is both read and written) and mirrors
  // the updated value into out so the kernel also produces a regular output.
  for (size_t pos = start; pos < end; ++pos) {
    const T sum = input1[pos] + input2[pos];
    input1[pos] = sum;
    out[pos] = sum;
  }
}
template <typename T> template <typename T>
void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) { void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) { for (size_t i = start; i < end; i++) {
@@ -65,11 +73,16 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
operate_type_ = MUL; operate_type_ = MUL;
} else if (kernel_name == "Div") { } else if (kernel_name == "Div") {
operate_type_ = DIV; operate_type_ = DIV;
} else if (kernel_name == prim::kPrimAssignAdd->name()) {
operate_type_ = ASSIGNADD;
} }
input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape1_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); input_shape1_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape_.size() == 0) {
output_shape_.insert(output_shape_.begin(), 1);
}
size_t l = input_shape0_.size(); size_t l = input_shape0_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) { for (size_t i = 0; i < output_shape_.size() - l; ++i) {
input_shape0_.insert(input_shape0_.begin(), 1); input_shape0_.insert(input_shape0_.begin(), 1);
@@ -138,8 +151,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
T *input2 = reinterpret_cast<T *>(inputs[1]->addr); T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr); T *output = reinterpret_cast<T *>(outputs[0]->addr);
auto lens = outputs[0]->size / sizeof(T); auto lens = outputs[0]->size / sizeof(T);
MS_LOG(INFO) << "lens=" << lens; size_t thread_num = lens < 128 * 24 ? std::ceil(lens / 128.0) : 24;
const size_t thread_num = 24; MS_LOG(INFO) << "lens=" << lens << "; use thread_num=" << thread_num;
std::vector<std::thread> threads; std::vector<std::thread> threads;
threads.reserve(thread_num); threads.reserve(thread_num);
size_t start = 0; size_t start = 0;
@@ -154,6 +167,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul<T>, this, input1, input2, output, start, end)); threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == DIV) { } else if (operate_type_ == DIV) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div<T>, this, input1, input2, output, start, end)); threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ASSIGNADD) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::AssignAdd<T>, this, input1, input2, output, start, end));
} }
start += once_compute_size; start += once_compute_size;
} }

View File

@@ -45,6 +45,8 @@ class ArithmeticCPUKernel : public CPUKernel {
void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end); void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T> template <typename T>
void Div(const T *input1, const T *input2, T *out, size_t start, size_t end); void Div(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end);
std::vector<size_t> input_shape0_; std::vector<size_t> input_shape0_;
std::vector<size_t> input_shape1_; std::vector<size_t> input_shape1_;
std::vector<size_t> input_element_num0_; std::vector<size_t> input_element_num0_;
@@ -64,6 +66,12 @@ MS_REG_CPU_KERNEL(
MS_REG_CPU_KERNEL( MS_REG_CPU_KERNEL(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel); ArithmeticCPUKernel);
// Register the AssignAdd CPU kernel for int32 and int64: two same-typed inputs,
// one same-typed output, all handled by ArithmeticCPUKernel (ASSIGNADD op path).
MS_REG_CPU_KERNEL(
  AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
  AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  ArithmeticCPUKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@@ -52,7 +52,7 @@ const char END[] = "end";
const char SIZE[] = "size"; const char SIZE[] = "size";
const char USE_NESTEROV[] = "use_nesterov"; const char USE_NESTEROV[] = "use_nesterov";
const char GROUP[] = "group"; const char GROUP[] = "group";
enum OperateType { ADD = 0, SUB, MUL, DIV, SQUARE, SQRT }; enum OperateType { ADD = 0, SUB, MUL, DIV, SQUARE, SQRT, ASSIGNADD };
class CPUKernel : public kernel::KernelMod { class CPUKernel : public kernel::KernelMod {
public: public: