forked from mindspore-Ecosystem/mindspore
bias_add_int8 op inherit add_int8 op
This commit is contained in:
parent
a38c996c9c
commit
6fbfe4b4ec
|
@ -33,7 +33,7 @@ class QuantizedAddCPUKernel : public LiteKernel {
|
|||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
arith_para_ = reinterpret_cast<ArithmeticParameter *>(parameter);
|
||||
}
|
||||
~QuantizedAddCPUKernel() override {}
|
||||
~QuantizedAddCPUKernel() override = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
|
|
|
@ -15,10 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/int8/bias_add_int8.h"
|
||||
#include "nnacl/fp32/arithmetic_fp32.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
|
@ -26,52 +23,5 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::schema::PrimitiveType_BiasAdd;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int BiasAddInt8CPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int BiasAddInt8CPUKernel::ReSize() {
|
||||
auto bias_param = reinterpret_cast<ArithmeticParameter *>(op_parameter_);
|
||||
auto dims = in_tensors_.at(0)->shape();
|
||||
bias_param->ndim_ = dims.size();
|
||||
if (bias_param->ndim_ < 1 || bias_param->ndim_ > 5) {
|
||||
MS_LOG(ERROR) << "input shape is invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < bias_param->ndim_; i++) {
|
||||
bias_param->in_shape0_[i] = dims[i];
|
||||
bias_param->in_shape1_[i] = 1;
|
||||
bias_param->out_shape_[i] = dims[i];
|
||||
}
|
||||
bias_param->in_shape1_[bias_param->ndim_ - 1] = dims[bias_param->ndim_ - 1];
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int BiasAddInt8CPUKernel::Run() {
|
||||
auto in = reinterpret_cast<int8_t *>(in_tensors_.at(0)->MutableData());
|
||||
auto bias = reinterpret_cast<int8_t *>(in_tensors_.at(1)->MutableData());
|
||||
auto out = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
|
||||
size_t data_size = in_tensors_.at(0)->ElementsNum();
|
||||
auto tile_in = static_cast<int8_t *>(ctx_->allocator->Malloc(data_size));
|
||||
if (tile_in == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to malloc momery";
|
||||
return NNACL_ERR;
|
||||
}
|
||||
auto tile_bias = static_cast<int8_t *>(ctx_->allocator->Malloc(data_size));
|
||||
if (tile_bias == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to malloc momery";
|
||||
ctx_->allocator->Free(tile_in);
|
||||
return NNACL_ERR;
|
||||
}
|
||||
BroadcastAddInt8(in, bias, tile_in, tile_bias, out, data_size,
|
||||
reinterpret_cast<ArithmeticParameter *>(op_parameter_));
|
||||
ctx_->allocator->Free(tile_in);
|
||||
ctx_->allocator->Free(tile_bias);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BiasAdd, LiteKernelCreator<BiasAddInt8CPUKernel>)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -17,26 +17,16 @@
|
|||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "nnacl/arithmetic.h"
|
||||
#include "nnacl/int8/add_int8.h"
|
||||
#include "nnacl/int8/arithmetic_int8.h"
|
||||
#include "src/runtime/kernel/arm/int8/add_int8.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class BiasAddInt8CPUKernel : public LiteKernel {
|
||||
class BiasAddInt8CPUKernel : public QuantizedAddCPUKernel {
|
||||
public:
|
||||
BiasAddInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx) {}
|
||||
~BiasAddInt8CPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
const lite::InnerContext *ctx_;
|
||||
: QuantizedAddCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~BiasAddInt8CPUKernel() override = default;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
Loading…
Reference in New Issue