!9420 [MSLITE] fp32 add optimize

From: @ling_qiao_min
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-12-04 10:43:35 +08:00 committed by Gitee
commit 3eb49a4a7e
7 changed files with 266 additions and 184 deletions

View File

@@ -153,8 +153,8 @@ int TransposeFp16CPUKernel::Run() {
fp16_out_data_ = reinterpret_cast<float16_t *>(out_tensor->MutableData());
}
in_shape_ = const_cast<int *>(in_tensor->shape().data());
out_shape_ = const_cast<int *>(out_tensor->shape().data());
memcpy(in_shape_, in_tensor->shape().data(), in_tensor->shape().size() * sizeof(int));
memcpy(out_shape_, out_tensor->shape().data(), out_tensor->shape().size() * sizeof(int));
ret = ParallelLaunch(this->context_->thread_pool_, TransposeFp16Run, this, thread_h_num_);
if (ret != RET_OK) {

View File

@@ -48,8 +48,8 @@ class TransposeFp16CPUKernel : public LiteKernel {
float *out_data_;
float16_t *fp16_in_data_ = nullptr;
float16_t *fp16_out_data_ = nullptr;
int *in_shape_;
int *out_shape_;
int in_shape_[8];
int out_shape_[8];
};
} // namespace mindspore::kernel
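
Both transpose files above make the same fix: the old code cached a raw pointer into in_tensor->shape().data() (const_cast away the constness), which dangles if the shape vector is reallocated or freed between Resize() and Run(). The new code copies the shape into a fixed int[8] member instead. A minimal sketch of the pattern, with hypothetical names, assuming as the kernel does that tensors have at most 8 dimensions:

#include <cstring>
#include <vector>

constexpr int kMaxShapeSize = 8;  // mirrors the int in_shape_[8] member above

struct ShapeHoldingKernel {
  int in_shape_[kMaxShapeSize] = {0};

  // Defensive copy: stays valid even if the tensor's shape vector is
  // reallocated or destroyed later; a cached shape().data() pointer would not.
  void CaptureShape(const std::vector<int> &shape) {
    std::memcpy(in_shape_, shape.data(), shape.size() * sizeof(int));
  }
};
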

View File

@@ -30,7 +30,10 @@ using mindspore::schema::PrimitiveType_Eltwise;
namespace mindspore::kernel {
ArithmeticCPUKernel::~ArithmeticCPUKernel() {}
ArithmeticCPUKernel::~ArithmeticCPUKernel() {
FreeTmpPtr();
return;
}
int ArithmeticCPUKernel::Init() {
if (!InferShapeDone()) {
@@ -39,6 +42,59 @@ int ArithmeticCPUKernel::Init() {
return ReSize();
}
int ArithmeticCPUKernel::InitBroadCastCase() {
/* if a const input needs broadcasting
* and every input that needs broadcasting is const,
* do the broadcast once at resize time */
if (arithmeticParameter_->broadcasting_ == false) {
return RET_OK;
}
int broadcast_size = out_tensors_[0]->Size();
if (broadcast_size < 0) {
return RET_OK;
}
if (arithmeticParameter_->in_elements_num0_ != arithmeticParameter_->out_elements_num_ &&
arithmeticParameter_->in_elements_num1_ != arithmeticParameter_->out_elements_num_) {
/* [1, 1, 2] + [1, 2, 1] -> [1, 2, 2]
* both inputs need broadcasting */
return RET_OK;
}
FreeTmpPtr();
CalcMultiplesAndStrides(arithmeticParameter_);
if (in_tensors_[0]->data_c() != nullptr &&
arithmeticParameter_->in_elements_num1_ == arithmeticParameter_->out_elements_num_) {
input0_ptr_ = malloc(broadcast_size);
if (input0_ptr_ == nullptr) {
return RET_ERROR;
}
TileOneDimension(reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(input0_ptr_), 0,
arithmeticParameter_->ndim_, arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_,
arithmeticParameter_->out_strides_, arithmeticParameter_->multiples0_);
arithmeticParameter_->broadcasting_ = false;
input0_broadcast_ = true;
}
if (in_tensors_[1]->data_c() != nullptr &&
arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->out_elements_num_) {
input1_ptr_ = malloc(broadcast_size);
if (input1_ptr_ == nullptr) {
FreeTmpPtr();
return RET_ERROR;
}
TileOneDimension(reinterpret_cast<float *>(in_tensors_[1]->data_c()), reinterpret_cast<float *>(input1_ptr_), 0,
arithmeticParameter_->ndim_, arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_,
arithmeticParameter_->out_strides_, arithmeticParameter_->multiples1_);
arithmeticParameter_->broadcasting_ = false;
input1_broadcast_ = true;
}
return RET_OK;
}
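
For illustration, a simplified stand-in for TileOneDimension showing what InitBroadCastCase() precomputes: the const input is expanded to the output shape once, broadcasting_ is cleared, and every later Run() degenerates to a plain element-wise loop. The recursive float-only version below is a sketch, not the nnacl implementation:

#include <cstring>

// Repeats the input multiple[d] times along each dimension d, writing the
// fully broadcast tensor into out. Strides are assumed precomputed from the
// padded shapes, as CalcMultiplesAndStrides() does above.
void TileRecursive(const float *in, float *out, int dim, int ndim,
                   const int *in_shape, const int *in_strides,
                   const int *out_strides, const int *multiple) {
  if (dim == ndim - 1) {
    for (int m = 0; m < multiple[dim]; ++m) {
      std::memcpy(out + m * in_shape[dim], in, in_shape[dim] * sizeof(float));
    }
    return;
  }
  for (int i = 0; i < in_shape[dim]; ++i) {
    for (int m = 0; m < multiple[dim]; ++m) {
      TileRecursive(in + i * in_strides[dim],
                    out + (m * in_shape[dim] + i) * out_strides[dim],
                    dim + 1, ndim, in_shape, in_strides, out_strides, multiple);
    }
  }
}
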
int ArithmeticCPUKernel::PreProcess() {
if (!InferShapeDone()) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(true);
@@ -73,25 +129,98 @@ int ArithmeticCPUKernel::PreProcess() {
return RET_OK;
}
int ArithmeticCPUKernel::ReSize() {
auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
data_type_ = kDataTypeFloat;
} else {
data_type_ = kDataTypeInt;
void ArithmeticCPUKernel::InitRunFunction() {
switch (op_parameter_->type_) {
case PrimitiveType_Mul:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementMulRelu;
arithmetic_run_int_ = ElementMulReluInt;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementMulRelu6;
arithmetic_run_int_ = ElementMulRelu6Int;
break;
default:
arithmetic_run_ = ElementMul;
arithmetic_run_int_ = ElementMulInt;
break;
}
break;
case PrimitiveType_Add:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementAddRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementAddRelu6;
break;
default:
arithmetic_run_ = ElementAdd;
arithmetic_run_int_ = ElementAddInt;
break;
}
break;
case PrimitiveType_Sub:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementSubRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementSubRelu6;
break;
default:
arithmetic_run_ = ElementSub;
arithmetic_run_int_ = ElementSubInt;
break;
}
break;
case PrimitiveType_Div:
case PrimitiveType_RealDiv:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementDivRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementDivRelu6;
break;
default:
arithmetic_run_ = ElementDiv;
break;
}
break;
case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAnd;
break;
case PrimitiveType_LogicalOr:
arithmetic_run_ = ElementLogicalOr;
break;
case PrimitiveType_Maximum:
arithmetic_run_ = ElementMaximum;
break;
case PrimitiveType_Minimum:
arithmetic_run_ = ElementMinimum;
break;
case PrimitiveType_FloorDiv:
arithmetic_run_ = ElementFloorDiv;
arithmetic_run_int_ = ElementFloorDivInt;
break;
case PrimitiveType_FloorMod:
arithmetic_run_ = ElementFloorMod;
arithmetic_run_int_ = ElementFloorModInt;
break;
case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifference;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
arithmetic_run_ = nullptr;
break;
}
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
return;
}
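
InitRunFunction() binds one fused routine per (operator, activation) pair, so DoArithmetic() later dispatches through a single function pointer with no per-element branching. A scalar sketch in the style of ElementAddRelu; the real nnacl kernels are SIMD-vectorized and their exact signatures live in the nnacl arithmetic headers:

// Fusing the activation into the same loop avoids a second pass over out.
int ElementAddReluSketch(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) {
    float sum = in0[i] + in1[i];
    out[i] = sum > 0.0f ? sum : 0.0f;
  }
  return 0;  // RET_OK
}
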
void ArithmeticCPUKernel::InitOptRunFunction() {
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
switch (arithmeticParameter_->op_parameter_.type_) {
case PrimitiveType_Mul:
@@ -163,23 +292,45 @@ int ArithmeticCPUKernel::ReSize() {
break;
}
break;
case PrimitiveType_Equal:
case PrimitiveType_Less:
case PrimitiveType_Greater:
case PrimitiveType_NotEqual:
case PrimitiveType_LessEqual:
case PrimitiveType_GreaterEqual:
default:
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
break;
default:
break;
}
} else {
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
}
return RET_OK;
return;
}
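
The opt runners selected above handle the case where one operand is a scalar: the scalar stays in a register instead of being tiled into a temporary buffer. A sketch in the spirit of nnacl's ElementOptAdd, with the parameter struct replaced by a bool for brevity:

int ElementOptAddSketch(const float *in0, const float *in1, float *out,
                        int size, bool first_is_scalar) {
  if (first_is_scalar) {  // in_elements_num0_ == 1 in the real kernel
    const float s = in0[0];
    for (int i = 0; i < size; ++i) out[i] = s + in1[i];
  } else {                // in_elements_num1_ == 1
    const float s = in1[0];
    for (int i = 0; i < size; ++i) out[i] = in0[i] + s;
  }
  return 0;  // RET_OK
}
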
void ArithmeticCPUKernel::InitParam() {
auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
arithmeticParameter_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
arithmeticParameter_->ndim_ = arithmetic_lite_primitive->NDims();
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
data_type_ = kDataTypeFloat;
} else {
data_type_ = kDataTypeInt;
}
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
return;
}
int ArithmeticCPUKernel::ReSize() {
InitParam();
InitOptRunFunction();
return InitBroadCastCase();
}
int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
@@ -229,7 +380,8 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
}
int error_code;
if (arithmeticParameter_->broadcasting_) { // need broadcast
if (arithmeticParameter_->broadcasting_) {
/* needs broadcasting at runtime */
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
@@ -237,59 +389,57 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
}
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->data_c()),
reinterpret_cast<float *>(in_tensors_[1]->data_c()),
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<int *>(in_tensors_[0]->data_c()),
reinterpret_cast<int *>(in_tensors_[1]->data_c()),
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
return error_code;
}
} else if (arithmetic_opt_run_ != nullptr) { // no broadcast, one of the inputs is a scalar
if (arithmetic_opt_run_ != nullptr) {
/* run the opt function:
* one of the inputs is a scalar */
if (arithmeticParameter_->in_elements_num0_ == 1) {
if (data_type_ == kDataTypeFloat) {
error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()),
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count,
arithmeticParameter_);
error_code = arithmetic_opt_run_(
reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
} else {
error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()),
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id,
count, arithmeticParameter_);
error_code = arithmetic_opt_run_int_(
reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
}
} else if (arithmeticParameter_->in_elements_num1_ == 1) {
if (data_type_ == kDataTypeFloat) {
error_code = arithmetic_opt_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<float *>(in_tensors_[1]->data_c()),
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count,
arithmeticParameter_);
error_code = arithmetic_opt_run_(
reinterpret_cast<float *>(input0_ptr_) + stride * task_id, reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
} else {
error_code = arithmetic_opt_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<int *>(in_tensors_[1]->data_c()),
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id,
count, arithmeticParameter_);
error_code = arithmetic_opt_run_int_(
reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
}
} else {
MS_LOG(ERROR) << "Arithmetic opt run: at least one of inputs is scalar";
return RET_ERROR;
}
} else { // no broadcast, neither input is scalar, both have the same shape
if (data_type_ == kDataTypeFloat) {
error_code = arithmetic_run_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} else {
error_code = arithmetic_run_int_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count);
}
return error_code;
}
if (error_code != RET_OK) {
return RET_ERROR;
/* no broadcasting at runtime */
if (data_type_ == kDataTypeFloat) {
error_code = arithmetic_run_(reinterpret_cast<float *>(input0_ptr_) + stride * task_id,
reinterpret_cast<float *>(input1_ptr_) + stride * task_id,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} else {
error_code = arithmetic_run_int_(reinterpret_cast<int *>(input0_ptr_) + stride * task_id,
reinterpret_cast<int *>(input1_ptr_) + stride * task_id,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count);
}
return RET_OK;
return error_code;
}
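
The per-thread partitioning used throughout DoArithmetic(): UP_DIV rounds the stride up and MSMIN clamps the last slice, so a count <= 0 simply means that thread has nothing to do. Worked instance for 10 elements on 4 threads: stride = UP_DIV(10, 4) = 3, tasks 0..2 take 3 elements at offsets 0, 3, 6, and task 3 takes the single element at offset 9. A minimal sketch using the nnacl-style macros:

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

// Elements assigned to task_id; the slice starts at stride * task_id.
static inline int SliceCount(int total, int thread_count, int task_id) {
  int stride = UP_DIV(total, thread_count);
  return MSMIN(stride, total - stride * task_id);
}
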
int ArithmeticsRun(void *cdata, int task_id) {
@@ -302,7 +452,22 @@ int ArithmeticsRun(void *cdata, int task_id) {
return RET_OK;
}
int ArithmeticCPUKernel::Run() {
void ArithmeticCPUKernel::FreeTmpPtr() {
if (input0_broadcast_ && input0_ptr_ != nullptr) {
free(input0_ptr_);
input0_ptr_ = nullptr;
input0_broadcast_ = false;
}
if (input1_broadcast_ && input1_ptr_ != nullptr) {
free(input1_ptr_);
input1_ptr_ = nullptr;
input1_broadcast_ = false;
}
return;
}
void ArithmeticCPUKernel::InitParamInRunTime() {
/* runs after InferShape */
if (arithmeticParameter_->broadcasting_) {
outside_ = 1;
for (auto i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
@@ -312,13 +477,24 @@ int ArithmeticCPUKernel::Run() {
}
outside_ *= arithmeticParameter_->out_shape_[i];
}
ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
}
ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
if (!input0_broadcast_) {
input0_ptr_ = in_tensors_[0]->data_c();
}
if (!input1_broadcast_) {
input1_ptr_ = in_tensors_[1]->data_c();
}
return;
}
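
ComputeStrides(), called above, derives row-major strides so BroadcastRun() can walk tensors of any rank with plain pointer arithmetic; for shape [2, 3, 4] it yields strides [12, 4, 1]. A minimal sketch of the computation:

void ComputeStridesSketch(const int *shape, int *strides, int ndim) {
  int stride = 1;
  for (int i = ndim - 1; i >= 0; --i) {
    strides[i] = stride;  // product of all dimensions after i
    stride *= shape[i];
  }
}
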
int ArithmeticCPUKernel::Run() {
InitParamInRunTime();
int error_code = ParallelLaunch(this->context_->thread_pool_, ArithmeticsRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]";
return RET_ERROR;
@@ -370,5 +546,4 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, CpuArithmeticFp32KernelCreator)
} // namespace mindspore::kernel

View File

@@ -56,111 +56,7 @@ class ArithmeticCPUKernel : public LiteKernel {
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
switch (parameter->type_) {
case PrimitiveType_Mul:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementMulRelu;
arithmetic_run_int_ = ElementMulReluInt;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementMulRelu6;
arithmetic_run_int_ = ElementMulRelu6Int;
break;
default:
arithmetic_run_ = ElementMul;
arithmetic_run_int_ = ElementMulInt;
break;
}
break;
case PrimitiveType_Add:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementAddRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementAddRelu6;
break;
default:
arithmetic_run_ = ElementAdd;
arithmetic_run_int_ = ElementAddInt;
break;
}
break;
case PrimitiveType_Sub:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementSubRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementSubRelu6;
break;
default:
arithmetic_run_ = ElementSub;
arithmetic_run_int_ = ElementSubInt;
break;
}
break;
case PrimitiveType_Div:
case PrimitiveType_RealDiv:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementDivRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementDivRelu6;
break;
default:
arithmetic_run_ = ElementDiv;
break;
}
break;
case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAnd;
break;
case PrimitiveType_LogicalOr:
arithmetic_run_ = ElementLogicalOr;
break;
case PrimitiveType_Maximum:
arithmetic_run_ = ElementMaximum;
break;
case PrimitiveType_Minimum:
arithmetic_run_ = ElementMinimum;
break;
case PrimitiveType_FloorDiv:
arithmetic_run_ = ElementFloorDiv;
arithmetic_run_int_ = ElementFloorDivInt;
break;
case PrimitiveType_FloorMod:
arithmetic_run_ = ElementFloorMod;
arithmetic_run_int_ = ElementFloorModInt;
break;
case PrimitiveType_Equal:
arithmetic_run_ = ElementEqual;
break;
case PrimitiveType_NotEqual:
arithmetic_run_ = ElementNotEqual;
break;
case PrimitiveType_Less:
arithmetic_run_ = ElementLess;
break;
case PrimitiveType_LessEqual:
arithmetic_run_ = ElementLessEqual;
break;
case PrimitiveType_Greater:
arithmetic_run_ = ElementGreater;
break;
case PrimitiveType_GreaterEqual:
arithmetic_run_ = ElementGreaterEqual;
break;
case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifference;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
arithmetic_run_ = nullptr;
break;
}
InitRunFunction();
}
~ArithmeticCPUKernel() override;
@@ -171,6 +67,20 @@ class ArithmeticCPUKernel : public LiteKernel {
virtual int DoArithmetic(int task_id);
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
private:
void InitRunFunction();
void InitOptRunFunction();
void InitParam();
void FreeTmpPtr();
int InitBroadCastCase();
void InitParamInRunTime();
private:
bool input0_broadcast_ = false;
bool input1_broadcast_ = false;
void *input0_ptr_ = nullptr;
void *input1_ptr_ = nullptr;
protected:
int break_pos_ = 0;
int outside_ = 0;

View File

@@ -263,7 +263,6 @@ int MatmulCPUKernel::RunImpl(int task_id) {
MS_ASSERT(cur_a_ptr_);
MS_ASSERT(b);
MS_ASSERT(c);
MS_ASSERT(bias);
if (is_vector_a_) {
MatVecMul(cur_a_ptr_, b, c, bias, ActType_No, params_->deep_, cur_oc);
} else {

View File

@@ -143,8 +143,8 @@ int TransposeInt8CPUKernel::Run() {
in_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
out_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());
in_shape_ = in_dims.data();
out_shape_ = out_dims.data();
memcpy(in_shape_, in_dims.data(), in_dims.size() * sizeof(int));
memcpy(out_shape_, out_dims.data(), out_dims.size() * sizeof(int));
int ret = MallocTmpBuf();
if (ret != RET_OK) {
@@ -157,8 +157,6 @@ int TransposeInt8CPUKernel::Run() {
}
FreeTmpBuf();
in_shape_ = nullptr;
out_shape_ = nullptr;
return ret;
}

View File

@@ -48,8 +48,6 @@ class TransposeInt8CPUKernel : public LiteKernel {
TransposeParameter *transpose_param_;
int8_t *in_ptr_ = nullptr;
int8_t *out_ptr_ = nullptr;
int *in_shape_ = nullptr;
int *out_shape_ = nullptr;
int *dim_size_ = nullptr;
int *position_ = nullptr;
bool extra_dims_ = false;
@@ -57,6 +55,8 @@ class TransposeInt8CPUKernel : public LiteKernel {
int thread_h_stride_ = 0;
int thread_h_num_ = 0;
int num_unit_ = 0;
int in_shape_[8];
int out_shape_[8];
};
} // namespace mindspore::kernel