forked from mindspore-Ecosystem/mindspore
!9420 [MSLITE] fp32 add optimize
From: @ling_qiao_min
This commit is contained in:
commit 3eb49a4a7e
@@ -153,8 +153,8 @@ int TransposeFp16CPUKernel::Run() {
     fp16_out_data_ = reinterpret_cast<float16_t *>(out_tensor->MutableData());
   }

-  in_shape_ = const_cast<int *>(in_tensor->shape().data());
-  out_shape_ = const_cast<int *>(out_tensor->shape().data());
+  memcpy(in_shape_, in_tensor->shape().data(), in_tensor->shape().size() * sizeof(int));
+  memcpy(out_shape_, out_tensor->shape().data(), out_tensor->shape().size() * sizeof(int));

   ret = ParallelLaunch(this->context_->thread_pool_, TransposeFp16Run, this, thread_h_num_);
   if (ret != RET_OK) {
@@ -48,8 +48,8 @@ class TransposeFp16CPUKernel : public LiteKernel {
   float *out_data_;
   float16_t *fp16_in_data_ = nullptr;
   float16_t *fp16_out_data_ = nullptr;
-  int *in_shape_;
-  int *out_shape_;
+  int in_shape_[8];
+  int out_shape_[8];
 };
 }  // namespace mindspore::kernel
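Note on the transpose changes above (the int8 kernel below gets the same fix): the kernel used to cache a pointer obtained from `shape().data()` across Resize and Run. If `shape()` returns by value, that pointer dangles as soon as the temporary vector dies; even if it returns a reference, a later reshape can reallocate the backing storage, and the `const_cast` invited writes into tensor-owned memory. Copying the values into a fixed `int[8]` member removes both hazards. A minimal illustration, not MindSpore code; `ShapeHolder`, `kMaxShapeSize`, and `BindShape` are made-up names:

// Illustrative only: cache shape values, not a pointer into someone else's vector.
#include <cstring>
#include <vector>

constexpr size_t kMaxShapeSize = 8;  // mirrors the int in_shape_[8] members above

struct ShapeHolder {
  int dims_[kMaxShapeSize] = {0};
  size_t ndim_ = 0;

  void BindShape(const std::vector<int> &shape) {
    ndim_ = shape.size() < kMaxShapeSize ? shape.size() : kMaxShapeSize;
    std::memcpy(dims_, shape.data(), ndim_ * sizeof(int));  // values outlive the vector
  }
};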
@@ -30,7 +30,10 @@ using mindspore::schema::PrimitiveType_Eltwise;
 namespace mindspore::kernel {

-ArithmeticCPUKernel::~ArithmeticCPUKernel() {}
+ArithmeticCPUKernel::~ArithmeticCPUKernel() {
+  FreeTmpPtr();
+  return;
+}

 int ArithmeticCPUKernel::Init() {
   if (!InferShapeDone()) {
@@ -39,6 +42,59 @@ int ArithmeticCPUKernel::Init() {
   return ReSize();
 }

+int ArithmeticCPUKernel::InitBroadCastCase() {
+  /* if a const input needs broadcast,
+   * and every input that needs broadcast is const,
+   * do the broadcast once in Resize */
+
+  if (arithmeticParameter_->broadcasting_ == false) {
+    return RET_OK;
+  }
+
+  int broadcast_size = out_tensors_[0]->Size();
+  if (broadcast_size < 0) {
+    return RET_OK;
+  }
+
+  if (arithmeticParameter_->in_elements_num0_ != arithmeticParameter_->out_elements_num_ &&
+      arithmeticParameter_->in_elements_num1_ != arithmeticParameter_->out_elements_num_) {
+    /* [1, 1, 2] + [1, 2, 1] -> [1, 2, 2]
+     * both inputs need broadcast */
+    return RET_OK;
+  }
+
+  FreeTmpPtr();
+
+  CalcMultiplesAndStrides(arithmeticParameter_);
+
+  if (in_tensors_[0]->data_c() != nullptr &&
+      arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) {
+    input0_ptr_ = malloc(broadcast_size);
+    if (input0_ptr_ == nullptr) {
+      return RET_ERROR;
+    }
+    TileOneDimension(reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(input0_ptr_), 0,
+                     arithmeticParameter_->ndim_, arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_,
+                     arithmeticParameter_->out_strides_, arithmeticParameter_->multiples0_);
+    arithmeticParameter_->broadcasting_ = false;
+    input0_broadcast_ = true;
+  }
+  if (in_tensors_[1]->data_c() != nullptr &&
+      arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) {
+    input1_ptr_ = malloc(broadcast_size);
+    if (input1_ptr_ == nullptr) {
+      FreeTmpPtr();
+      return RET_ERROR;
+    }
+    TileOneDimension(reinterpret_cast<float *>(in_tensors_[1]->data_c()), reinterpret_cast<float *>(input1_ptr_), 0,
+                     arithmeticParameter_->ndim_, arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_,
+                     arithmeticParameter_->out_strides_, arithmeticParameter_->multiples1_);
+    arithmeticParameter_->broadcasting_ = false;
+    input1_broadcast_ = true;
+  }
+  return RET_OK;
+}
+
 int ArithmeticCPUKernel::PreProcess() {
   if (!InferShapeDone()) {
     (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(true);
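The core of the optimization is InitBroadCastCase() above: when exactly one input needs broadcasting and its data is already available (a const tensor), it is tiled into a malloc'ed buffer once during Resize, `broadcasting_` is cleared, and Run() falls through to the flat element-wise path. A rough sketch of the idea with hypothetical names, under the simplifying assumption that the constant input tiles contiguously (the real code walks one dimension at a time via TileOneDimension):

// Rough sketch (not the MindSpore API): expand a constant input once, up front.
#include <cstdlib>

float *PrecomputeBroadcast(const float *const_in, int in_size, int out_size) {
  float *buf = static_cast<float *>(malloc(out_size * sizeof(float)));
  if (buf == nullptr) {
    return nullptr;
  }
  for (int i = 0; i < out_size; ++i) {
    buf[i] = const_in[i % in_size];  // contiguous-tiling assumption
  }
  return buf;  // caller frees (cf. FreeTmpPtr in the diff)
}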
@@ -73,25 +129,98 @@ int ArithmeticCPUKernel::PreProcess() {
   return RET_OK;
 }

-int ArithmeticCPUKernel::ReSize() {
-  auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
-  arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
-  arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
-  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
-    data_type_ = kDataTypeFloat;
-  } else {
-    data_type_ = kDataTypeInt;
-  }
-  arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
-  arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
-  arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
-  memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
-         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
-  memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
-         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
-  memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
-         reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
+void ArithmeticCPUKernel::InitRunFunction() {
+  switch (op_parameter_->type_) {
+    case PrimitiveType_Mul:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementMulRelu;
+          arithmetic_run_int_ = ElementMulReluInt;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementMulRelu6;
+          arithmetic_run_int_ = ElementMulRelu6Int;
+          break;
+        default:
+          arithmetic_run_ = ElementMul;
+          arithmetic_run_int_ = ElementMulInt;
+          break;
+      }
+      break;
+    case PrimitiveType_Add:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementAddRelu;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementAddRelu6;
+          break;
+        default:
+          arithmetic_run_ = ElementAdd;
+          arithmetic_run_int_ = ElementAddInt;
+          break;
+      }
+      break;
+    case PrimitiveType_Sub:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementSubRelu;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementSubRelu6;
+          break;
+        default:
+          arithmetic_run_ = ElementSub;
+          arithmetic_run_int_ = ElementSubInt;
+          break;
+      }
+      break;
+    case PrimitiveType_Div:
+    case PrimitiveType_RealDiv:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementDivRelu;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementDivRelu6;
+          break;
+        default:
+          arithmetic_run_ = ElementDiv;
+          break;
+      }
+      break;
+    case PrimitiveType_LogicalAnd:
+      arithmetic_run_ = ElementLogicalAnd;
+      break;
+    case PrimitiveType_LogicalOr:
+      arithmetic_run_ = ElementLogicalOr;
+      break;
+    case PrimitiveType_Maximum:
+      arithmetic_run_ = ElementMaximum;
+      break;
+    case PrimitiveType_Minimum:
+      arithmetic_run_ = ElementMinimum;
+      break;
+    case PrimitiveType_FloorDiv:
+      arithmetic_run_ = ElementFloorDiv;
+      arithmetic_run_int_ = ElementFloorDivInt;
+      break;
+    case PrimitiveType_FloorMod:
+      arithmetic_run_ = ElementFloorMod;
+      arithmetic_run_int_ = ElementFloorModInt;
+      break;
+    case PrimitiveType_SquaredDifference:
+      arithmetic_run_ = ElementSquaredDifference;
+      break;
+    default:
+      MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
+      arithmetic_run_ = nullptr;
+      break;
+  }
+  return;
+}
+
+void ArithmeticCPUKernel::InitOptRunFunction() {
   if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
     switch (arithmeticParameter_->op_parameter_.type_) {
       case PrimitiveType_Mul:
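InitRunFunction() resolves the kernel function pointer once from the op type and the fused activation, so the per-element loop carries no activation branch at all. A minimal self-contained sketch of the pattern, using invented Demo names rather than the real Element* kernels:

// Sketch only: resolve the element-wise kernel once, run it many times.
typedef int (*ElementRunFunc)(const float *in0, const float *in1, float *out, int size);

int ElementAddDemo(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) {
    out[i] = in0[i] + in1[i];
  }
  return 0;
}

int ElementAddReluDemo(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) {
    float v = in0[i] + in1[i];
    out[i] = v > 0.0f ? v : 0.0f;  // fused ReLU: no second pass over the data
  }
  return 0;
}

// Chosen once in Init/Resize; the hot loop just calls run(...).
ElementRunFunc SelectRun(bool fuse_relu) { return fuse_relu ? ElementAddReluDemo : ElementAddDemo; }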
@@ -163,23 +292,45 @@ int ArithmeticCPUKernel::ReSize() {
           break;
       }
       break;
-      case PrimitiveType_Equal:
-      case PrimitiveType_Less:
-      case PrimitiveType_Greater:
-      case PrimitiveType_NotEqual:
-      case PrimitiveType_LessEqual:
-      case PrimitiveType_GreaterEqual:
-      default:
-        arithmetic_opt_run_ = nullptr;
-        arithmetic_opt_run_int_ = nullptr;
-        break;
+      default:
+        break;
     }
   } else {
     arithmetic_opt_run_ = nullptr;
     arithmetic_opt_run_int_ = nullptr;
   }
-  return RET_OK;
+  return;
 }

+void ArithmeticCPUKernel::InitParam() {
+  auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
+  arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
+  arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
+  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
+    data_type_ = kDataTypeFloat;
+  } else {
+    data_type_ = kDataTypeInt;
+  }
+
+  arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
+  arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
+  arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
+  memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
+  memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
+  memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
+
+  return;
+}
+
+int ArithmeticCPUKernel::ReSize() {
+  InitParam();
+  InitOptRunFunction();
+  return InitBroadCastCase();
+}
+
 int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
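With the helpers split out, ReSize() reads as a three-step pipeline: InitParam() snapshots shapes, element counts, and data type from the primitive; InitOptRunFunction() selects the scalar fast path when either operand has a single element; and InitBroadCastCase() pre-expands a const operand so that only genuinely dynamic shapes broadcast at run time.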
@@ -229,7 +380,8 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
   }

   int error_code;
-  if (arithmeticParameter_->broadcasting_) {  // need broadcast
+  if (arithmeticParameter_->broadcasting_) {
+    /* need broadcast in runtime */
     stride = UP_DIV(outside_, thread_count_);
     int out_count = MSMIN(stride, outside_ - stride * task_id);
     if (out_count <= 0) {
@@ -237,59 +389,57 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
     }
     int out_thread_stride = stride * task_id;
     if (data_type_ == kDataTypeFloat) {
-      error_code = BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->data_c()),
-                                reinterpret_cast<float *>(in_tensors_[1]->data_c()),
+      error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
                                 reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
     } else {
-      error_code = BroadcastRun(reinterpret_cast<int *>(in_tensors_[0]->data_c()),
-                                reinterpret_cast<int *>(in_tensors_[1]->data_c()),
+      error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
                                 reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
     }
+    return error_code;
+  }

-  } else if (arithmetic_opt_run_ != nullptr) {  // no broadcast, one of input is scalar
+  if (arithmetic_opt_run_ != nullptr) {
+    /* run opt function:
+     * one of the inputs is a scalar */
     if (arithmeticParameter_->in_elements_num0_ == 1) {
       if (data_type_ == kDataTypeFloat) {
-        error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()),
-                                         reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
-                                         reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count,
-                                         arithmeticParameter_);
+        error_code = arithmetic_opt_run_(
+          reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
+          reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
       } else {
-        error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()),
-                                             reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
-                                             reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id,
-                                             count, arithmeticParameter_);
+        error_code = arithmetic_opt_run_int_(
+          reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
+          reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
       }
     } else if (arithmeticParameter_->in_elements_num1_ == 1) {
       if (data_type_ == kDataTypeFloat) {
-        error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
-                                         reinterpret_cast<float *>(in_tensors_[1]->data_c()),
-                                         reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count,
-                                         arithmeticParameter_);
+        error_code = arithmetic_opt_run_(
+          reinterpret_cast<float *>(input0_ptr_) + stride * task_id, reinterpret_cast<float *>(input1_ptr_),
+          reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
       } else {
-        error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
-                                             reinterpret_cast<int *>(in_tensors_[1]->data_c()),
-                                             reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id,
-                                             count, arithmeticParameter_);
+        error_code = arithmetic_opt_run_int_(
+          reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_),
+          reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
       }
     } else {
       MS_LOG(ERROR) << "Arithmetic opt run: at least one of the inputs must be a scalar";
       return RET_ERROR;
     }
-  } else {  // no broadcast, neither is scalar, two same shape
-    if (data_type_ == kDataTypeFloat) {
-      error_code = arithmetic_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
-                                   reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
-                                   reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count);
-    } else {
-      error_code = arithmetic_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
-                                       reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
-                                       reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count);
-    }
-  }
-  if (error_code != RET_OK) {
-    return RET_ERROR;
-  }
-  return RET_OK;
+    return error_code;
+  }
+
+  /* no broadcast in runtime */
+  if (data_type_ == kDataTypeFloat) {
+    error_code = arithmetic_run_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
+                                 reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
+                                 reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count);
+  } else {
+    error_code = arithmetic_run_int_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
+                                     reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
+                                     reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count);
+  }
+  return error_code;
 }

 int ArithmeticsRun(void *cdata, int task_id) {
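Every branch of DoArithmetic() slices the work the same way: stride = UP_DIV(total, thread_count_), then each task clamps its chunk with MSMIN so tail threads never run past the end. A small self-contained sketch of that slicing, assuming the usual MSLITE macro definitions (UP_DIV(x, y) = (x + y - 1) / y, MSMIN(a, b) = min):

// Per-thread slicing as used above, with the macros expanded inline.
void SliceWork(int total, int thread_count, int task_id, int *begin, int *count) {
  int stride = (total + thread_count - 1) / thread_count;  // UP_DIV(total, thread_count)
  *begin = stride * task_id;
  int remain = total - *begin;                             // may be <= 0 for tail threads
  *count = remain < stride ? (remain > 0 ? remain : 0) : stride;  // MSMIN, clamped at zero
}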
@@ -302,7 +452,22 @@ int ArithmeticsRun(void *cdata, int task_id) {
   return RET_OK;
 }

-int ArithmeticCPUKernel::Run() {
+void ArithmeticCPUKernel::FreeTmpPtr() {
+  if (input0_broadcast_ == true && input0_ptr_ != nullptr) {
+    free(input0_ptr_);
+    input0_ptr_ = nullptr;
+    input0_broadcast_ = false;
+  }
+  if (input1_broadcast_ == true && input1_ptr_ != nullptr) {
+    free(input1_ptr_);
+    input1_ptr_ = nullptr;
+    input1_broadcast_ = false;
+  }
+  return;
+}
+
+void ArithmeticCPUKernel::InitParamInRunTime() {
+  /* after infershape */
   if (arithmeticParameter_->broadcasting_) {
     outside_ = 1;
     for (auto i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
@@ -312,13 +477,24 @@ int ArithmeticCPUKernel::Run() {
       }
       outside_ *= arithmeticParameter_->out_shape_[i];
     }
-    ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
-    ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
   }
+  ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
+  ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
+  ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
+
+  if (!input0_broadcast_) {
+    input0_ptr_ = in_tensors_[0]->data_c();
+  }
+  if (!input1_broadcast_) {
+    input1_ptr_ = in_tensors_[1]->data_c();
+  }
+  return;
+}
+
+int ArithmeticCPUKernel::Run() {
+  InitParamInRunTime();
+
   int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, thread_count_);

   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]";
     return RET_ERROR;
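InitParamInRunTime() is what lets the three DoArithmetic() paths share one pair of pointers: strides are now computed unconditionally, and input0_ptr_/input1_ptr_ are rebound to the live tensor buffers only when InitBroadCastCase() did not already install a pre-broadcast buffer for that input.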
@@ -370,5 +546,4 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator)
-
 }  // namespace mindspore::kernel
@@ -56,111 +56,7 @@ class ArithmeticCPUKernel : public LiteKernel {
                      const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
     arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
-    switch (parameter->type_) {
-      case PrimitiveType_Mul:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementMulRelu;
-            arithmetic_run_int_ = ElementMulReluInt;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementMulRelu6;
-            arithmetic_run_int_ = ElementMulRelu6Int;
-            break;
-          default:
-            arithmetic_run_ = ElementMul;
-            arithmetic_run_int_ = ElementMulInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Add:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementAddRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementAddRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementAdd;
-            arithmetic_run_int_ = ElementAddInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Sub:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementSubRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementSubRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementSub;
-            arithmetic_run_int_ = ElementSubInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Div:
-      case PrimitiveType_RealDiv:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementDivRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementDivRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementDiv;
-            break;
-        }
-        break;
-      case PrimitiveType_LogicalAnd:
-        arithmetic_run_ = ElementLogicalAnd;
-        break;
-      case PrimitiveType_LogicalOr:
-        arithmetic_run_ = ElementLogicalOr;
-        break;
-      case PrimitiveType_Maximum:
-        arithmetic_run_ = ElementMaximum;
-        break;
-      case PrimitiveType_Minimum:
-        arithmetic_run_ = ElementMinimum;
-        break;
-      case PrimitiveType_FloorDiv:
-        arithmetic_run_ = ElementFloorDiv;
-        arithmetic_run_int_ = ElementFloorDivInt;
-        break;
-      case PrimitiveType_FloorMod:
-        arithmetic_run_ = ElementFloorMod;
-        arithmetic_run_int_ = ElementFloorModInt;
-        break;
-      case PrimitiveType_Equal:
-        arithmetic_run_ = ElementEqual;
-        break;
-      case PrimitiveType_NotEqual:
-        arithmetic_run_ = ElementNotEqual;
-        break;
-      case PrimitiveType_Less:
-        arithmetic_run_ = ElementLess;
-        break;
-      case PrimitiveType_LessEqual:
-        arithmetic_run_ = ElementLessEqual;
-        break;
-      case PrimitiveType_Greater:
-        arithmetic_run_ = ElementGreater;
-        break;
-      case PrimitiveType_GreaterEqual:
-        arithmetic_run_ = ElementGreaterEqual;
-        break;
-      case PrimitiveType_SquaredDifference:
-        arithmetic_run_ = ElementSquaredDifference;
-        break;
-      default:
-        MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
-        arithmetic_run_ = nullptr;
-        break;
-    }
+    InitRunFunction();
   }
   ~ArithmeticCPUKernel() override;
@@ -171,6 +67,20 @@ class ArithmeticCPUKernel : public LiteKernel {
   virtual int DoArithmetic(int task_id);
   virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);

+ private:
+  void InitRunFunction();
+  void InitOptRunFunction();
+  void InitParam();
+  void FreeTmpPtr();
+  int InitBroadCastCase();
+  void InitParamInRunTime();
+
+ private:
+  bool input0_broadcast_ = false;
+  bool input1_broadcast_ = false;
+  void *input0_ptr_ = nullptr;
+  void *input1_ptr_ = nullptr;
+
  protected:
   int break_pos_ = 0;
   int outside_ = 0;
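The header change is the flip side of InitRunFunction(): the roughly 100-line dispatch switch moves out of the constructor into the .cc file, keeping kernel selection out of the class body. The new private members record which inputs own a pre-broadcast buffer, so FreeTmpPtr() (called from the destructor and from InitBroadCastCase()) releases exactly what was malloc'ed and never a tensor-owned pointer.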
@@ -263,7 +263,6 @@ int MatmulCPUKernel::RunImpl(int task_id) {
   MS_ASSERT(cur_a_ptr_);
   MS_ASSERT(b);
   MS_ASSERT(c);
-  MS_ASSERT(bias);
   if (is_vector_a_) {
     MatVecMul(cur_a_ptr_, b, c, bias, ActType_No, params_->deep_, cur_oc);
   } else {
@@ -143,8 +143,8 @@ int TransposeInt8CPUKernel::Run() {
   in_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
   out_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());

-  in_shape_ = in_dims.data();
-  out_shape_ = out_dims.data();
+  memcpy(in_shape_, in_dims.data(), in_dims.size() * sizeof(int));
+  memcpy(out_shape_, out_dims.data(), out_dims.size() * sizeof(int));

   int ret = MallocTmpBuf();
   if (ret != RET_OK) {
@@ -157,8 +157,6 @@ int TransposeInt8CPUKernel::Run() {
   }

   FreeTmpBuf();
-  in_shape_ = nullptr;
-  out_shape_ = nullptr;
   return ret;
 }
|
@ -48,8 +48,6 @@ class TransposeInt8CPUKernel : public LiteKernel {
|
|||
TransposeParameter *transpose_param_;
|
||||
int8_t *in_ptr_ = nullptr;
|
||||
int8_t *out_ptr_ = nullptr;
|
||||
int *in_shape_ = nullptr;
|
||||
int *out_shape_ = nullptr;
|
||||
int *dim_size_ = nullptr;
|
||||
int *position_ = nullptr;
|
||||
bool extra_dims_ = false;
|
||||
|
@@ -57,6 +55,8 @@ class TransposeInt8CPUKernel : public LiteKernel {
   int thread_h_stride_ = 0;
   int thread_h_num_ = 0;
   int num_unit_ = 0;
+  int in_shape_[8];
+  int out_shape_[8];
 };
 }  // namespace mindspore::kernel