bias_add_int8 op inherit add_int8 op

This commit is contained in:
fuzhiye 2021-02-24 14:42:35 +08:00
parent a38c996c9c
commit 6fbfe4b4ec
3 changed files with 5 additions and 65 deletions

View File

@ -33,7 +33,7 @@ class QuantizedAddCPUKernel : public LiteKernel {
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
arith_para_ = reinterpret_cast<ArithmeticParameter *>(parameter);
}
~QuantizedAddCPUKernel() override {}
~QuantizedAddCPUKernel() override = default;
int Init() override;
int ReSize() override;

View File

@ -15,10 +15,7 @@
*/
#include "src/runtime/kernel/arm/int8/bias_add_int8.h"
#include "nnacl/fp32/arithmetic_fp32.h"
#include "nnacl/errorcode.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@ -26,52 +23,5 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_BiasAdd;
namespace mindspore::kernel {
int BiasAddInt8CPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int BiasAddInt8CPUKernel::ReSize() {
auto bias_param = reinterpret_cast<ArithmeticParameter *>(op_parameter_);
auto dims = in_tensors_.at(0)->shape();
bias_param->ndim_ = dims.size();
if (bias_param->ndim_ < 1 || bias_param->ndim_ > 5) {
MS_LOG(ERROR) << "input shape is invalid";
return RET_ERROR;
}
for (size_t i = 0; i < bias_param->ndim_; i++) {
bias_param->in_shape0_[i] = dims[i];
bias_param->in_shape1_[i] = 1;
bias_param->out_shape_[i] = dims[i];
}
bias_param->in_shape1_[bias_param->ndim_ - 1] = dims[bias_param->ndim_ - 1];
return RET_OK;
}
int BiasAddInt8CPUKernel::Run() {
auto in = reinterpret_cast<int8_t *>(in_tensors_.at(0)->MutableData());
auto bias = reinterpret_cast<int8_t *>(in_tensors_.at(1)->MutableData());
auto out = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
size_t data_size = in_tensors_.at(0)->ElementsNum();
auto tile_in = static_cast<int8_t *>(ctx_->allocator->Malloc(data_size));
if (tile_in == nullptr) {
MS_LOG(ERROR) << "Failed to malloc momery";
return NNACL_ERR;
}
auto tile_bias = static_cast<int8_t *>(ctx_->allocator->Malloc(data_size));
if (tile_bias == nullptr) {
MS_LOG(ERROR) << "Failed to malloc momery";
ctx_->allocator->Free(tile_in);
return NNACL_ERR;
}
BroadcastAddInt8(in, bias, tile_in, tile_bias, out, data_size,
reinterpret_cast<ArithmeticParameter *>(op_parameter_));
ctx_->allocator->Free(tile_in);
ctx_->allocator->Free(tile_bias);
return NNACL_OK;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BiasAdd, LiteKernelCreator<BiasAddInt8CPUKernel>)
} // namespace mindspore::kernel

View File

@ -17,26 +17,16 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/arithmetic.h"
#include "nnacl/int8/add_int8.h"
#include "nnacl/int8/arithmetic_int8.h"
#include "src/runtime/kernel/arm/int8/add_int8.h"
namespace mindspore::kernel {
class BiasAddInt8CPUKernel : public LiteKernel {
class BiasAddInt8CPUKernel : public QuantizedAddCPUKernel {
public:
BiasAddInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx) {}
~BiasAddInt8CPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
private:
const lite::InnerContext *ctx_;
: QuantizedAddCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~BiasAddInt8CPUKernel() override = default;
};
} // namespace mindspore::kernel