bias_add opt
parent 6bda22b90a
commit 0d7411b345
nnacl/fp32/bias_add.c (new file)
@@ -0,0 +1,142 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp32/bias_add.h"
#include "nnacl/op_base.h"

void BiasAddByInnerCore(const float *input, const float *bias, float *output, int64_t num) {
  int64_t index = 0;
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
  for (; index <= num - C4NUM; index += C4NUM) {
    MS_FLOAT32X4 input_data = MS_LDQ_F32(input + index);
    MS_FLOAT32X4 bias_data = MS_LDQ_F32(bias + index);
    MS_STQ_F32(output + index, MS_ADD128_F32(input_data, bias_data));
  }
#endif
  for (; index < num; ++index) {
    output[index] = input[index] + bias[index];
  }
}

void BiasAddByBatchCore(const float *input, const float *bias, float *output, int64_t num) {
  float *output1 = output;
  float *output2 = output + num;
  float *output3 = output + num * 2;
  float *output4 = output + num * 3;
  int64_t index = 0;
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
  for (; index <= num - C4NUM; index += C4NUM) {
    MS_LOAD128X4_F32(input_data, input + index, num);
    MS_FLOAT32X4 bias_data = MS_LDQ_F32(bias + index);
    MS_STQ_F32(output1 + index, MS_ADD128_F32(input_data1, bias_data));
    MS_STQ_F32(output2 + index, MS_ADD128_F32(input_data2, bias_data));
    MS_STQ_F32(output3 + index, MS_ADD128_F32(input_data3, bias_data));
    MS_STQ_F32(output4 + index, MS_ADD128_F32(input_data4, bias_data));
  }
#endif
  const float *input_data1 = input;
  const float *input_data2 = input + num;
  const float *input_data3 = input + num * 2;
  const float *input_data4 = input + num * 3;
  for (; index < num; ++index) {
    output1[index] = input_data1[index] + bias[index];
    output2[index] = input_data2[index] + bias[index];
    output3[index] = input_data3[index] + bias[index];
    output4[index] = input_data4[index] + bias[index];
  }
}

void DoBiasAddByBatch(const float *input, const float *bias, float *output, int64_t start, int64_t end,
                      int64_t inner_num) {
  if (inner_num == 0) {
    return;
  }
  int64_t start_outer = start / inner_num;
  int64_t start_inner = start % inner_num;
  int64_t end_outer = end / inner_num;
  int64_t end_inner = end % inner_num;
  const float *cur_input = input + start;
  const float *cur_bias = bias + start_inner;
  float *cur_output = output + start;
  if (start_outer == end_outer) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner - start_inner);
    return;
  }
  if (start_inner != 0) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num - start_inner);
    start_outer += 1;
    cur_input += inner_num - start_inner;
    cur_bias = bias;
    cur_output += inner_num - start_inner;
  }
  int64_t step = C4NUM * inner_num;
  for (; start_outer <= end_outer - C4NUM; start_outer += C4NUM) {
    BiasAddByBatchCore(cur_input, cur_bias, cur_output, inner_num);
    cur_input += step;
    cur_output += step;
  }
  for (; start_outer < end_outer; ++start_outer) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num);
    cur_input += inner_num;
    cur_output += inner_num;
  }
  BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner);
}

void DoBiasAddByInner(const float *input, const float *bias, float *output, int64_t start, int64_t end,
                      int64_t inner_num) {
  if (inner_num == 0) {
    return;
  }
  int64_t start_outer = start / inner_num;
  int64_t start_inner = start % inner_num;
  int64_t end_outer = end / inner_num;
  int64_t end_inner = end % inner_num;
  const float *cur_input = input + start;
  const float *cur_bias = bias + start_inner;
  float *cur_output = output + start;
  if (start_outer == end_outer) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner - start_inner);
    return;
  } else {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num - start_inner);
    start_outer += 1;
    cur_input += inner_num - start_inner;
    cur_bias = bias;
    cur_output += inner_num - start_inner;
  }
  if (start_outer == end_outer) {
    BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner);
    return;
  } else {
    for (; start_outer < end_outer; ++start_outer) {
      BiasAddByInnerCore(cur_input, cur_bias, cur_output, inner_num);
      cur_input += inner_num;
      cur_output += inner_num;
    }
  }
  BiasAddByInnerCore(cur_input, cur_bias, cur_output, end_inner);
}

void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num,
                bool batch_priority) {
  if (batch_priority) {
    DoBiasAddByBatch(input, bias, output, start, end, inner_num);
  } else {
    DoBiasAddByInner(input, bias, output, start, end, inner_num);
  }
}
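Note on the range semantics (editorial sketch, not part of the commit): BiasAddOpt and its Do* helpers operate on a flat element range [start, end) of the output, where inner_num is the flattened bias length and the input is treated as outer_num rows of inner_num elements. A minimal scalar reference with the same semantics, under that reading, would be:

    // BiasAddNaive: hypothetical reference, equivalent to BiasAddOpt on [start, end).
    // Assumes inner_num > 0, matching the guard in the Do* helpers.
    #include <cstdint>

    void BiasAddNaive(const float *input, const float *bias, float *output,
                      int64_t start, int64_t end, int64_t inner_num) {
      for (int64_t i = start; i < end; ++i) {
        output[i] = input[i] + bias[i % inner_num];  // bias repeats every inner_num elements
      }
    }

DoBiasAddByBatch reproduces this by peeling off a partial first row, streaming four full rows at a time through BiasAddByBatchCore, and finishing row by row; DoBiasAddByInner does the same without the four-row batching.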

nnacl/fp32/bias_add.h (new file)
@@ -0,0 +1,34 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_H_
#define MINDSPORE_NNACL_FP32_BIAS_ADD_H_

#include <stdint.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num,
                bool batch_priority);

#ifdef __cplusplus
};
#endif

#endif  // MINDSPORE_NNACL_FP32_BIAS_ADD_H_
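Usage sketch (hypothetical, for illustration only): a [2, 3] input plus a [3] bias, processed as a single segment that covers the whole tensor.

    #include <cstdio>
    #include "nnacl/fp32/bias_add.h"

    int main() {
      float input[6] = {0, 1, 2, 3, 4, 5};
      float bias[3] = {10, 20, 30};
      float output[6] = {0};
      // start = 0, end = 6 spans all elements; inner_num = 3 is the bias length.
      BiasAddOpt(input, bias, output, 0, 6, 3, false);
      for (int i = 0; i < 6; ++i) printf("%g ", output[i]);  // 10 21 32 13 24 35
      return 0;
    }

In the kernel below, each worker thread makes one such call on its own [start, end) segment; segments never overlap, so no locking is needed.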

src/runtime/kernel/arm/fp32/bias_fp32.cc
@@ -16,6 +16,7 @@
 
 #include "src/runtime/kernel/arm/fp32/bias_fp32.h"
 #include <vector>
+#include "nnacl/fp32/bias_add.h"
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -27,39 +28,13 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_BiasAdd;
 
 namespace mindspore::kernel {
-int BiasCPUKernel::ReSize() {
-  auto dims = in_tensors_.at(0)->shape();
-  bias_param_->ndim_ = dims.size();
-  if (bias_param_->ndim_ < 1 || bias_param_->ndim_ > 5) {
-    MS_LOG(ERROR) << "input shape is invalid";
-    return RET_ERROR;
-  }
-  for (size_t i = 0; i < bias_param_->ndim_; i++) {
-    bias_param_->in_shape0_[i] = dims[i];
-    bias_param_->in_shape1_[i] = 1;
-    bias_param_->out_shape_[i] = dims[i];
-  }
-  bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1];
-  return RET_OK;
-}
-
-int BiasCPUKernel::Run() {
-  auto in = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
-  auto bias = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
-  auto out = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
-  size_t data_size = static_cast<size_t>(in_tensors_.at(0)->ElementsNum());
-  CHECK_NULL_RETURN(ms_context_->allocator);
-  float *tile_in = reinterpret_cast<float *>(ms_context_->allocator->Malloc(data_size * sizeof(float)));
-  float *tile_bias = reinterpret_cast<float *>(ms_context_->allocator->Malloc(data_size * sizeof(float)));
-  if (tile_in == nullptr || tile_bias == nullptr) {
-    MS_LOG(ERROR) << "Memory allocation failed";
-    ms_context_->allocator->Free(tile_in);
-    ms_context_->allocator->Free(tile_bias);
-    return RET_ERROR;
-  }
-  auto ret = BroadcastAdd(in, bias, tile_in, tile_bias, out, static_cast<int>(data_size), bias_param_);
-  ms_context_->allocator->Free(tile_in);
-  ms_context_->allocator->Free(tile_bias);
-  return ret;
-}
+int BiasAddRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
+  CHECK_NULL_RETURN(cdata);
+  auto kernel = reinterpret_cast<BiasCPUKernel *>(cdata);
+  auto ret = kernel->DoExecute(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "BiasAddRun error task_id[" << task_id << "] error_code[" << ret << "]";
+  }
+  return ret;
+}
 
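Editorial note on the removed path (a reconstruction from the deleted lines above, not commit text): the old Run() paid for genericity before doing any arithmetic, roughly:

    // Old approach: materialize both operands at full output size, then add.
    float *tile_in = Malloc(data_size * sizeof(float));    // scratch copy of the input
    float *tile_bias = Malloc(data_size * sizeof(float));  // bias tiled out to data_size
    BroadcastAdd(in, bias, tile_in, tile_bias, out, data_size, bias_param_);

That is two full-size allocations plus a broadcast copy per inference. The replacement below adds the bias in place over per-thread segments, so the scratch buffers disappear entirely.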
@@ -73,5 +48,79 @@ int BiasCPUKernel::Prepare() {
   return ReSize();
 }
 
+int BiasCPUKernel::ReSize() {
+  auto in_dims = in_tensors_.at(0)->shape();
+  auto bias_dims = in_tensors_.at(1)->shape();
+  if (bias_dims.empty() || in_dims.empty() || in_dims.size() < bias_dims.size()) {
+    MS_LOG(ERROR) << "inTensors' shapes are invalid.";
+    return RET_ERROR;
+  }
+  size_t dim_offset = in_dims.size() - bias_dims.size();
+  inner_num_ = 1;
+  for (size_t i = 0; i < bias_dims.size(); ++i) {
+    if (in_dims[i + dim_offset] != bias_dims[i]) {
+      MS_LOG(ERROR) << "inTensors' shapes do not match.";
+      return RET_ERROR;
+    }
+    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(bias_dims[i], inner_num_), RET_ERROR, "mul overflow.");
+    inner_num_ *= bias_dims[i];
+  }
+  outer_num_ = 1;
+  for (size_t i = 0; i < dim_offset; ++i) {
+    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(in_dims[i], outer_num_), RET_ERROR, "mul overflow.");
+    outer_num_ *= in_dims[i];
+  }
+  MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(inner_num_, outer_num_), RET_ERROR, "mul overflow.");
+  total_num_ = inner_num_ * outer_num_;
+  GetThreadSegmentInfos();
+  return RET_OK;
+}
+
+void BiasCPUKernel::GetThreadSegmentInfos() {
+  split_start_points_ = std::vector<int64_t>(op_parameter_->thread_num_, 0);
+  split_end_points_ = std::vector<int64_t>(op_parameter_->thread_num_, 0);
+  int64_t step = MSMAX(total_num_ / op_parameter_->thread_num_, C128NUM);
+  int64_t remain_data = MSMAX(total_num_ - step * op_parameter_->thread_num_, 0);
+  for (int i = 0; i < op_parameter_->thread_num_; ++i) {
+    if (i == 0) {
+      split_end_points_[i] = MSMIN(step, total_num_) + (i < remain_data ? 1 : 0);
+      continue;
+    }
+    split_start_points_[i] = split_end_points_[i - 1];
+    if (split_start_points_[i] >= total_num_) {
+      split_start_points_[i] = 0;
+      break;
+    }
+    split_end_points_[i] =
+      split_start_points_[i] + MSMIN(step, total_num_ - split_start_points_[i]) + (i < remain_data ? 1 : 0);
+  }
+  MS_ASSERT(inner_num_ != 0);
+  if (inner_num_ >= C64NUM && step / inner_num_ >= C6NUM) {
+    batch_priority_ = true;
+  } else {
+    batch_priority_ = false;
+  }
+}
+
+int BiasCPUKernel::Run() {
+  auto ret = ParallelLaunch(this->ms_context_, BiasAddRun, this, op_parameter_->thread_num_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "BiasAddRun error: error_code[" << ret << "]";
+  }
+  return ret;
+}
+
+int BiasCPUKernel::DoExecute(int task_id) {
+  auto input = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
+  auto bias = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
+  auto output = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
+  if (split_start_points_[task_id] == split_end_points_[task_id]) {
+    return lite::RET_OK;
+  }
+  BiasAddOpt(input, bias, output, split_start_points_[task_id], split_end_points_[task_id], inner_num_,
+             batch_priority_);
+  return lite::RET_OK;
+}
+
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, LiteKernelCreator<BiasCPUKernel>)
 } // namespace mindspore::kernel
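How the segmentation works (editorial sketch; the standalone form and names are hypothetical, but the arithmetic mirrors GetThreadSegmentInfos): each thread gets a contiguous chunk of at least C128NUM (128) elements, the first `remain` threads absorb one extra element each, and a thread whose start would land past the end keeps an empty {0, 0} segment and returns early from DoExecute.

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    constexpr int64_t kMinStep = 128;  // mirrors C128NUM: below this, extra threads idle

    std::vector<std::pair<int64_t, int64_t>> SplitRange(int64_t total, int64_t threads) {
      std::vector<std::pair<int64_t, int64_t>> segs(threads, {0, 0});
      int64_t step = std::max(total / threads, kMinStep);
      int64_t remain = std::max<int64_t>(total - step * threads, 0);
      int64_t start = 0;
      for (int64_t i = 0; i < threads && start < total; ++i) {
        int64_t end = start + std::min(step, total - start) + (i < remain ? 1 : 0);
        segs[i] = {start, std::min(end, total)};
        start = segs[i].second;
      }
      return segs;  // e.g. SplitRange(1003, 4) -> {0,251} {251,502} {502,753} {753,1003}
    }

The batch_priority_ flag chosen at the end of GetThreadSegmentInfos gates the four-row path: it only pays off when rows are wide (inner_num_ >= C64NUM) and a thread's segment spans several rows (step / inner_num_ >= C6NUM); otherwise the row-by-row inner path wins.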

src/runtime/kernel/arm/fp32/bias_fp32.h
@@ -33,9 +33,17 @@ class BiasCPUKernel : public InnerKernel {
   int Prepare() override;
   int ReSize() override;
   int Run() override;
+  int DoExecute(int task_id);
 
  private:
+  void GetThreadSegmentInfos();
   ArithmeticParameter *bias_param_;
+  bool batch_priority_{false};
+  int64_t inner_num_{0};
+  int64_t outer_num_{0};
+  int64_t total_num_{0};
+  std::vector<int64_t> split_start_points_;
+  std::vector<int64_t> split_end_points_;
 };
 } // namespace mindspore::kernel