forked from mindspore-Ecosystem/mindspore
!4176 [MS][LITE][Develop]optimize infershape when running graph
Merge pull request !4176 from chenjianping/lite_dev2
This commit is contained in:
commit
05f405c0bc
|
@ -62,6 +62,7 @@ class LiteKernel {
|
|||
const lite::Primitive *primitive)
|
||||
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive),
|
||||
context_(ctx) {
|
||||
opParameter->thread_num_ = ctx->thread_num_;
|
||||
this->in_kernel_.clear();
|
||||
this->out_kernel_.clear();
|
||||
}
|
||||
|
@ -69,12 +70,13 @@ class LiteKernel {
|
|||
virtual ~LiteKernel() { delete opParameter; }
|
||||
|
||||
virtual int Prepare() {
|
||||
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
|
||||
if (!InferShapeDone()) {
|
||||
(const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_);
|
||||
if (need_reinit) {
|
||||
Init();
|
||||
}
|
||||
}
|
||||
if (need_reinit) {
|
||||
Init();
|
||||
}
|
||||
|
||||
auto &outputs = this->GetOutputs();
|
||||
for (auto *output : outputs) {
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
@ -126,6 +128,13 @@ class LiteKernel {
|
|||
}
|
||||
|
||||
protected:
|
||||
bool InferShapeDone() {
|
||||
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
KernelKey desc;
|
||||
std::string name;
|
||||
OpParameter *opParameter = nullptr;
|
||||
|
|
|
@ -32,10 +32,6 @@ using mindspore::schema::PrimitiveType_ArgMin;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
int ArgMinMaxBaseCPUKernel::Init() {
|
||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||
SetNeedReInit();
|
||||
return RET_OK;
|
||||
}
|
||||
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
|
||||
switch (opParameter->type_) {
|
||||
case PrimitiveType_ArgMax:
|
||||
|
@ -49,8 +45,13 @@ int ArgMinMaxBaseCPUKernel::Init() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArgMinMaxBaseCPUKernel::ReSize() {
|
||||
auto in_shape = inputs_.at(0)->shape();
|
||||
auto dims_size = in_shape.size();
|
||||
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
|
||||
int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_;
|
||||
param->axis_ = axis;
|
||||
param->dims_size_ = dims_size;
|
||||
|
|
|
@ -26,15 +26,13 @@ class ArgMinMaxBaseCPUKernel : public LiteKernel {
|
|||
ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {
|
||||
opParameter->thread_num_ = ctx->thread_num_;
|
||||
}
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {}
|
||||
|
||||
virtual ~ArgMinMaxBaseCPUKernel() { FreeTmpMemory(); }
|
||||
|
||||
int Init() override;
|
||||
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
|
||||
int Run() override;
|
||||
|
||||
|
|
|
@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_BatchToSpace;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
int BatchToSpaceBaseCPUKernel::Init() {
|
||||
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
|
||||
return RET_FORMAT_ERR;
|
||||
}
|
||||
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
|
||||
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
|
||||
if (param->crops_[i] != 0) {
|
||||
|
@ -43,6 +39,14 @@ int BatchToSpaceBaseCPUKernel::Init() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int BatchToSpaceBaseCPUKernel::ReSize() {
|
||||
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
|
||||
return RET_FORMAT_ERR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *op_parameter, const lite::Context *ctx,
|
||||
|
|
|
@ -35,7 +35,7 @@ class BatchToSpaceBaseCPUKernel : public LiteKernel {
|
|||
|
||||
int Init() override;
|
||||
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
|
||||
int Run() override { return 0; }
|
||||
|
||||
|
|
|
@ -31,11 +31,9 @@ using mindspore::lite::RET_PARAM_INVALID;
|
|||
using mindspore::schema::PrimitiveType_DepthToSpace;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int DepthToSpaceBaseCPUKernel::Init() {
|
||||
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
||||
SetNeedReInit();
|
||||
return RET_OK;
|
||||
}
|
||||
int DepthToSpaceBaseCPUKernel::Init() { return RET_OK; }
|
||||
|
||||
int DepthToSpaceBaseCPUKernel::ReSize() {
|
||||
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
|
||||
return RET_FORMAT_ERR;
|
||||
|
|
|
@ -35,7 +35,7 @@ class DepthToSpaceBaseCPUKernel : public LiteKernel {
|
|||
|
||||
int Init() override;
|
||||
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
|
||||
int Run() override { return 0; }
|
||||
};
|
||||
|
|
|
@ -36,7 +36,15 @@ int ArgMinMaxCPUKernel::Init() {
|
|||
}
|
||||
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
|
||||
param->data_type_ = kNumberTypeFloat32;
|
||||
return RET_OK;
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int ArgMinMaxCPUKernel::ReSize() {
|
||||
ArgMinMaxBaseCPUKernel::FreeTmpMemory();
|
||||
return ArgMinMaxBaseCPUKernel::ReSize();
|
||||
}
|
||||
|
||||
int ArgMinMaxCPUKernel::Run() {
|
||||
|
|
|
@ -30,7 +30,7 @@ class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel {
|
|||
~ArgMinMaxCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -24,7 +24,19 @@ using mindspore::lite::RET_OK;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
int BatchToSpaceCPUKernel::Init() {
|
||||
return BatchToSpaceBaseCPUKernel::Init();
|
||||
auto ret = BatchToSpaceBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int BatchToSpaceCPUKernel::ReSize() {
|
||||
return BatchToSpaceBaseCPUKernel::ReSize();
|
||||
}
|
||||
|
||||
int BatchToSpaceCPUKernel::Run() {
|
||||
|
|
|
@ -29,7 +29,7 @@ class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel {
|
|||
~BatchToSpaceCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -37,7 +37,15 @@ int DepthToSpaceCPUKernel::Init() {
|
|||
}
|
||||
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
|
||||
param->data_type_size_ = sizeof(float);
|
||||
return RET_OK;
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int DepthToSpaceCPUKernel::ReSize() {
|
||||
return DepthToSpaceBaseCPUKernel::ReSize();
|
||||
}
|
||||
|
||||
int DepthToSpaceCPUKernel::Run() {
|
||||
|
|
|
@ -29,7 +29,7 @@ class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel {
|
|||
~DepthToSpaceCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -40,14 +40,21 @@ int ArgMinMaxInt8CPUKernel::Init() {
|
|||
auto out_quant_args = out_tensor->GetQuantParams();
|
||||
out_quant_arg_.scale_ = out_quant_args.front().scale;
|
||||
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
|
||||
return RET_OK;
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int ArgMinMaxInt8CPUKernel::ReSize() {
|
||||
return ArgMinMaxBaseCPUKernel::ReSize();
|
||||
}
|
||||
|
||||
int ArgMinMaxInt8CPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return RET_ERROR;
|
||||
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
|
||||
return ret;
|
||||
}
|
||||
auto input = inputs_.at(0);
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel {
|
|||
~ArgMinMaxInt8CPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
private:
|
||||
QuantArg in_quant_arg_;
|
||||
|
|
|
@ -38,7 +38,14 @@ int BatchToSpaceInt8CPUKernel::Init() {
|
|||
auto out_quant_args = out_tensor->GetQuantParams();
|
||||
out_quant_arg_.scale_ = out_quant_args.front().scale;
|
||||
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
|
||||
return RET_OK;
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int BatchToSpaceInt8CPUKernel::ReSize() {
|
||||
return BatchToSpaceBaseCPUKernel::ReSize();
|
||||
}
|
||||
|
||||
int BatchToSpaceInt8CPUKernel::Run() {
|
||||
|
|
|
@ -30,7 +30,7 @@ class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel {
|
|||
~BatchToSpaceInt8CPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
private:
|
||||
QuantArg in_quant_arg_;
|
||||
|
|
|
@ -42,7 +42,14 @@ int DepthToSpaceInt8CPUKernel::Init() {
|
|||
auto out_quant_args = out_tensor->GetQuantParams();
|
||||
out_quant_arg_.scale_ = out_quant_args.front().scale;
|
||||
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
|
||||
return RET_OK;
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int DepthToSpaceInt8CPUKernel::ReSize() {
|
||||
return DepthToSpaceBaseCPUKernel::ReSize();
|
||||
}
|
||||
|
||||
int DepthToSpaceInt8CPUKernel::Run() {
|
||||
|
|
|
@ -30,7 +30,7 @@ class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel {
|
|||
~DepthToSpaceInt8CPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override { return 0; }
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
private:
|
||||
QuantArg in_quant_arg_;
|
||||
|
|
Loading…
Reference in New Issue