diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 572fa0e9700..b3bd2c902c8 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -62,6 +62,7 @@ class LiteKernel { const lite::Primitive *primitive) : opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), context_(ctx) { + opParameter->thread_num_ = ctx->thread_num_; this->in_kernel_.clear(); this->out_kernel_.clear(); } @@ -69,12 +70,13 @@ class LiteKernel { virtual ~LiteKernel() { delete opParameter; } virtual int Prepare() { - if (primitive_ != nullptr && !primitive_->GetInferFlag()) { + if (!InferShapeDone()) { (const_cast(primitive_))->InferShape(inputs_, outputs_); + if (need_reinit) { + Init(); + } } - if (need_reinit) { - Init(); - } + auto &outputs = this->GetOutputs(); for (auto *output : outputs) { MS_ASSERT(output != nullptr); @@ -126,6 +128,13 @@ class LiteKernel { } protected: + bool InferShapeDone() { + if (primitive_ != nullptr && !primitive_->GetInferFlag()) { + return false; + } + return true; + } + KernelKey desc; std::string name; OpParameter *opParameter = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc index b3b6d4cba08..0230c7aea27 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc @@ -32,10 +32,6 @@ using mindspore::schema::PrimitiveType_ArgMin; namespace mindspore::kernel { int ArgMinMaxBaseCPUKernel::Init() { - if (context_->infer_shape_interrupt_ && !context_->running_) { - SetNeedReInit(); - return RET_OK; - } auto param = reinterpret_cast(opParameter); switch (opParameter->type_) { case PrimitiveType_ArgMax: @@ -49,8 +45,13 @@ int ArgMinMaxBaseCPUKernel::Init() { return RET_ERROR; } + return RET_OK; +} + +int ArgMinMaxBaseCPUKernel::ReSize() { auto in_shape = inputs_.at(0)->shape(); auto dims_size = in_shape.size(); + auto param = reinterpret_cast(opParameter); int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_; param->axis_ = axis; param->dims_size_ = dims_size; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h index 771ecce9025..9630fd4c79c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h @@ -26,15 +26,13 @@ class ArgMinMaxBaseCPUKernel : public LiteKernel { ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const lite::Primitive *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) { - opParameter->thread_num_ = ctx->thread_num_; - } + : LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {} virtual ~ArgMinMaxBaseCPUKernel() { FreeTmpMemory(); } int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc index dc320b049f5..459364679e8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc @@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_BatchToSpace; namespace mindspore::kernel { int BatchToSpaceBaseCPUKernel::Init() { - if (inputs_[0]->GetFormat() != schema::Format_NHWC) { - MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; - return RET_FORMAT_ERR; - } BatchToSpaceParameter *param = reinterpret_cast(this->opParameter); for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { if (param->crops_[i] != 0) { @@ -43,6 +39,14 @@ int BatchToSpaceBaseCPUKernel::Init() { return RET_OK; } +int BatchToSpaceBaseCPUKernel::ReSize() { + if (inputs_[0]->GetFormat() != schema::Format_NHWC) { + MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + return RET_OK; +} + kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *op_parameter, const lite::Context *ctx, diff --git a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h index e8e6f83ac97..8e7ca7299d7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h @@ -35,7 +35,7 @@ class BatchToSpaceBaseCPUKernel : public LiteKernel { int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override { return 0; } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc index cd90425efac..cfb555e0258 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc @@ -31,11 +31,9 @@ using mindspore::lite::RET_PARAM_INVALID; using mindspore::schema::PrimitiveType_DepthToSpace; namespace mindspore::kernel { -int DepthToSpaceBaseCPUKernel::Init() { - if (context_->infer_shape_interrupt_ && !context_->running_) { - SetNeedReInit(); - return RET_OK; - } +int DepthToSpaceBaseCPUKernel::Init() { return RET_OK; } + +int DepthToSpaceBaseCPUKernel::ReSize() { if (inputs_[0]->GetFormat() != schema::Format_NHWC) { MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; return RET_FORMAT_ERR; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h index 849934fb121..27a2ae73ba9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h @@ -35,7 +35,7 @@ class DepthToSpaceBaseCPUKernel : public LiteKernel { int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override { return 0; } }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc index 8ba1595ed92..a9cb85147cd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc @@ -36,7 +36,15 @@ int ArgMinMaxCPUKernel::Init() { } auto param = reinterpret_cast(opParameter); param->data_type_ = kNumberTypeFloat32; - return RET_OK; + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int ArgMinMaxCPUKernel::ReSize() { + ArgMinMaxBaseCPUKernel::FreeTmpMemory(); + return ArgMinMaxBaseCPUKernel::ReSize(); } int ArgMinMaxCPUKernel::Run() { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h index fdc8fbdd3a8..611e28ea6ce 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h @@ -30,7 +30,7 @@ class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel { ~ArgMinMaxCPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc index 79b86e4d6de..e42193d9f9a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc @@ -24,7 +24,19 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { int BatchToSpaceCPUKernel::Init() { - return BatchToSpaceBaseCPUKernel::Init(); + auto ret = BatchToSpaceBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } + + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int BatchToSpaceCPUKernel::ReSize() { + return BatchToSpaceBaseCPUKernel::ReSize(); } int BatchToSpaceCPUKernel::Run() { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h index 2ac09c455ad..a8060726a6c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h @@ -29,7 +29,7 @@ class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel { ~BatchToSpaceCPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc index 1d29781a3cd..15d9a203c53 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc @@ -37,7 +37,15 @@ int DepthToSpaceCPUKernel::Init() { } DepthToSpaceParameter *param = reinterpret_cast(opParameter); param->data_type_size_ = sizeof(float); - return RET_OK; + if (!InferShapeDone()) { + return RET_OK; + } + + return ReSize(); +} + +int DepthToSpaceCPUKernel::ReSize() { + return DepthToSpaceBaseCPUKernel::ReSize(); } int DepthToSpaceCPUKernel::Run() { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h index d4f273ad325..706f094c374 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h @@ -29,7 +29,7 @@ class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel { ~DepthToSpaceCPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc index 36216531b59..a35d1575f9e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc @@ -40,14 +40,21 @@ int ArgMinMaxInt8CPUKernel::Init() { auto out_quant_args = out_tensor->GetQuantParams(); out_quant_arg_.scale_ = out_quant_args.front().scale; out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; - return RET_OK; + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int ArgMinMaxInt8CPUKernel::ReSize() { + return ArgMinMaxBaseCPUKernel::ReSize(); } int ArgMinMaxInt8CPUKernel::Run() { auto ret = Prepare(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Prepare failed."; - return RET_ERROR; + MS_LOG(ERROR) << "Prepare fail!ret: " << ret; + return ret; } auto input = inputs_.at(0); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h index 1a7e331b5fd..b8a8762637b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h @@ -31,7 +31,7 @@ class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel { ~ArgMinMaxInt8CPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; private: QuantArg in_quant_arg_; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc index db0b7df7019..ac3a015bf13 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc @@ -38,7 +38,14 @@ int BatchToSpaceInt8CPUKernel::Init() { auto out_quant_args = out_tensor->GetQuantParams(); out_quant_arg_.scale_ = out_quant_args.front().scale; out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; - return RET_OK; + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int BatchToSpaceInt8CPUKernel::ReSize() { + return BatchToSpaceBaseCPUKernel::ReSize(); } int BatchToSpaceInt8CPUKernel::Run() { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h index 7755cbd09cd..94bb2280842 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h @@ -30,7 +30,7 @@ class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel { ~BatchToSpaceInt8CPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; private: QuantArg in_quant_arg_; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc index 75af0bf7d79..4d535ad5622 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc @@ -42,7 +42,14 @@ int DepthToSpaceInt8CPUKernel::Init() { auto out_quant_args = out_tensor->GetQuantParams(); out_quant_arg_.scale_ = out_quant_args.front().scale; out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; - return RET_OK; + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int DepthToSpaceInt8CPUKernel::ReSize() { + return DepthToSpaceBaseCPUKernel::ReSize(); } int DepthToSpaceInt8CPUKernel::Run() { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h index 4b3520950ee..6548918ce36 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h @@ -30,7 +30,7 @@ class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel { ~DepthToSpaceInt8CPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; private: QuantArg in_quant_arg_;