modify split cpu kernel

This commit is contained in:
shaoxiangdong 2021-04-30 17:04:05 +08:00
parent d48151ab1e
commit 77c83d18cf
2 changed files with 57 additions and 24 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "backend/kernel_compiler/cpu/split_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
@ -29,20 +31,19 @@ void SplitCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
template <typename T>
void SplitCPUKernel<T>::Reshape() {
input_size_ = 1;
dims_current_after_axis_ = 1;
dims_after_axis_ = 1;
axis_step_ = input_shape_[axis_] / output_num_;
param_ = new SplitParameter();
param_->num_split_ = output_num_;
param_->split_dim_ = axis_ >= 0 ? axis_ : input_shape_.size() + axis_;
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
input_size_ *= input_shape_[i];
if (i > axis_) {
dims_current_after_axis_ *= input_shape_[i];
dims_after_axis_ *= input_shape_[i];
}
if (i == axis_) {
dims_current_after_axis_ *= input_shape_[i];
}
param_->strides_[input_shape_.size() - 1] = 1;
for (int i = input_shape_.size() - 2; i >= 0; i--) {
param_->strides_[i] = param_->strides_[i + 1] * input_shape_[i + 1];
}
param_->split_sizes_ = new int[sizeof(int) * param_->num_split_];
int split_size = input_shape_[param_->split_dim_] / output_num_;
for (int i = 0; i < param_->num_split_; i++) {
param_->split_sizes_[i] = split_size;
}
}
@ -61,17 +62,42 @@ bool SplitCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
}
template <typename T>
void SplitCPUKernel<T>::LaunchSplit(const T *input, T **output, size_t size) {
void SplitCPUKernel<T>::LaunchSplit(T *input, T **output, size_t size) {
(void)std::transform(input_shape_.begin(), input_shape_.end(), std::back_inserter(input_shape_int_),
[](const int &value) { return static_cast<int>(value); });
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0;
size_t thread_num = size < block_size * max_thread_num ? std::ceil(size / block_size) : max_thread_num;
param_->split_count_ = size / (input_shape_[param_->split_dim_] * param_->strides_[param_->split_dim_]);
int num_unit = param_->split_count_ * param_->num_split_;
int thread_n_stride;
if (thread_num != 0) {
thread_n_stride = UP_DIV(num_unit, thread_num);
}
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
int num = i % dims_current_after_axis_ / dims_after_axis_;
int block = num / axis_step_;
int block_pos = i / dims_current_after_axis_ * axis_step_ * dims_after_axis_ +
num % axis_step_ * dims_after_axis_ + i % dims_after_axis_;
output[block][block_pos] = input[i];
}
int task_id = start / (size / thread_num);
int thread_offset = task_id * thread_n_stride;
int num_unit_thread = MSMIN(thread_n_stride, num_unit - task_id * thread_n_stride);
DoSplit(input, reinterpret_cast<void **>(output), &input_shape_int_[0], thread_offset, num_unit_thread, param_,
sizeof(T));
};
CPUKernelUtils::ParallelFor(task, size);
return;
}
template <typename T>
void SplitCPUKernel<T>::FreeTmpBuff() {
if (param_->split_sizes_ != nullptr) {
delete[] param_->split_sizes_;
param_->split_sizes_ = nullptr;
}
if (param_ != nullptr) {
delete param_;
param_ = nullptr;
}
return;
}
@ -86,6 +112,7 @@ void SplitCPUKernel<T>::LaunchKernel(const std::vector<AddressPtr> &inputs,
}
size_t size = static_cast<size_t>(inputs[0]->size / sizeof(T));
LaunchSplit(input, output, size);
FreeTmpBuff();
return;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@
#include <thread>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/kernel_compiler/cpu/nnacl/base/split_base.h"
namespace mindspore {
namespace kernel {
@ -42,7 +43,8 @@ class SplitCPUKernel : public CPUKernel {
private:
void CheckParam(const CNodePtr &kernel_node);
void Reshape();
void LaunchSplit(const T *input, T **output, size_t size);
void LaunchSplit(T *input, T **output, size_t size);
void FreeTmpBuff();
int64_t axis_;
int64_t output_num_;
int64_t axis_step_;
@ -53,7 +55,11 @@ class SplitCPUKernel : public CPUKernel {
std::vector<std::vector<size_t>> output_shape_list_;
std::vector<size_t> input_shape_;
std::vector<int> input_shape_int_;
TypeId dtype_{kTypeUnknown};
protected:
SplitParameter *param_ = nullptr;
};
MS_REG_CPU_KERNEL_T(