From f2addd3f6e903bca509a5f27f620efc2d13e1a8e Mon Sep 17 00:00:00 2001
From: chenjianping
Date: Fri, 11 Sep 2020 16:46:06 +0800
Subject: [PATCH] refresh resize interface

---
 mindspore/lite/include/lite_session.h |  2 +-
 mindspore/lite/src/lite_session.cc    | 35 ++++++++++++++++++++-------
 mindspore/lite/src/lite_session.h     |  9 +++++--
 mindspore/lite/src/scheduler.cc       | 25 +++++++++++++------
 4 files changed, 52 insertions(+), 19 deletions(-)

diff --git a/mindspore/lite/include/lite_session.h b/mindspore/lite/include/lite_session.h
index 8bed3996a4e..865cbb08579 100644
--- a/mindspore/lite/include/lite_session.h
+++ b/mindspore/lite/include/lite_session.h
@@ -116,7 +116,7 @@ class MS_API LiteSession {
   /// \param[in] inputs Define the new inputs shape.
   ///
   /// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
-  virtual int Resize(const std::vector<tensor::MSTensor *> &inputs) = 0;
+  virtual int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) = 0;
 };
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index a2e8861f199..65b54687e52 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -389,34 +389,51 @@ std::unordered_map<std::string, mindspore::tensor::MSTensor *> LiteSession::GetOutputs() const {
   return this->output_tensor_map_;
 }
 
-int LiteSession::ResizeInputs(const std::vector<tensor::MSTensor *> &inputs) {
+int LiteSession::ResizeInputs(const std::vector<tensor::MSTensor *> &inputs,
+                              const std::vector<std::vector<int>> &dims) {
   if (inputs.size() != inputs_.size()) {
     MS_LOG(ERROR) << "Inputs size " << inputs.size() << " is not equal to " << inputs_.size();
     return RET_PARAM_INVALID;
   }
+  if (dims.size() != inputs.size()) {
+    MS_LOG(ERROR) << "Input dims size " << dims.size() << " is not equal to the inputs size " << inputs.size();
+    return RET_PARAM_INVALID;
+  }
+
   for (size_t i = 0; i < inputs.size(); ++i) {
-    if (inputs[i] == nullptr) {
-      MS_LOG(ERROR) << "Input tensor is nullptr!";
+    if (inputs[i] != inputs_[i]) {
+      MS_LOG(ERROR) << "Input[" << i << "] tensor is not equal to the inputs have been saved!";
       return RET_PARAM_INVALID;
     }
-    inputs_[i]->set_shape(inputs[i]->shape());
+
+    inputs_[i]->set_shape(dims[i]);
   }
   return RET_OK;
 }
 
-int LiteSession::Resize(const std::vector<tensor::MSTensor *> &inputs) {
-  std::vector<tensor::Tensor *> inputs_old(inputs_);
-  auto ret = ResizeInputs(inputs);
+void LiteSession::ResetInputsShape(const std::vector<std::vector<int>> &dims) {
+  for (size_t i = 0; i < inputs_.size(); ++i) {
+    inputs_[i]->set_shape(dims[i]);
+  }
+}
+
+int LiteSession::Resize(const std::vector<tensor::MSTensor *> &inputs,
+                        const std::vector<std::vector<int>> &dims) {
+  std::vector<std::vector<int>> old_dims;
+  for (size_t i = 0; i < inputs_.size(); ++i) {
+    old_dims.push_back(inputs_[i]->shape());
+  }
+  auto ret = ResizeInputs(inputs, dims);
   if (ret != RET_OK) {
-    inputs_ = inputs_old;
+    ResetInputsShape(old_dims);
     return ret;
   }
 
   Scheduler scheduler(context_);
   ret = scheduler.ReSizeKernels(kernels_);
   if (ret != RET_OK) {
-    inputs_ = inputs_old;
+    ResetInputsShape(old_dims);
     auto resize_ret = scheduler.ReSizeKernels(kernels_);
     if (resize_ret != RET_OK) {
       MS_LOG(ERROR) << "restore kernel size fail!ret: " << resize_ret;
diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h
index 95a3b98a0d6..bd4a3d047bb 100644
--- a/mindspore/lite/src/lite_session.h
+++ b/mindspore/lite/src/lite_session.h
@@ -59,7 +59,8 @@ class LiteSession : public session::LiteSession {
   std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override;
 
-  int Resize(const std::vector<tensor::MSTensor *> &inputs) override;
+  int Resize(const std::vector<tensor::MSTensor *> &inputs,
+             const std::vector<std::vector<int>> &dims) override;
 
  protected:
   int ConvertTensors(const lite::Model *model);
@@ -80,7 +81,11 @@ class LiteSession : public session::LiteSession {
   void InitGraphOutputTensorMap(const lite::Model *model);
 
-  int ResizeInputs(const std::vector<tensor::MSTensor *> &inputs);
+  int ResizeInputs(const std::vector<tensor::MSTensor *> &inputs,
+                   const std::vector<std::vector<int>> &dims);
+
+ private:
+  void ResetInputsShape(const std::vector<std::vector<int>> &dims);
 
  protected:
   Context *context_ = nullptr;
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index e8ee1cec964..4b4eeaa19fa 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -52,6 +52,7 @@ int Scheduler::Schedule(const lite::Model *model, std::vector<tensor::Tensor *> *tensors
 }
 
 int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
+  bool infer_shape_interrupt = false;
   for (size_t i = 0; i < kernels.size(); ++i) {
     if (kernels[i] == nullptr) {
       MS_LOG(ERROR) << "input kernel is nullptr!";
@@ -64,15 +65,25 @@ int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
     }
     std::vector<tensor::Tensor *> &inputs = kernels[i]->in_tensors();
     std::vector<tensor::Tensor *> &outputs = kernels[i]->out_tensors();
+    primitive->SetInferFlag(!infer_shape_interrupt);
     auto ret = primitive->InferShape(inputs, outputs);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "InferShape failed, name: " << kernels[i]->name() << ", ret = " << ret;
-      return ret;
+    if (ret == RET_INFER_INVALID) {
+      MS_LOG(INFO) << "InferShape shouldn't be done before runtime, type:"
+                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()))
+                   << "flag set to false.";
+      primitive->SetInferFlag(false);
+      infer_shape_interrupt = true;
+    } else if (ret != RET_OK) {
+      MS_LOG(ERROR) << "InferShape failed, type: "
+                    << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()));
+      return RET_INFER_ERR;
     }
-    ret = kernels[i]->ReSize();
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "kernel " << kernels[i]->name() << " resize fail!ret = " << ret;
-      return ret;
+    if (!infer_shape_interrupt) {
+      ret = kernels[i]->ReSize();
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "kernel " << kernels[i]->name() << " resize fail!ret = " << ret;
+        return ret;
+      }
     }
   }
   return RET_OK;