forked from mindspore-Ecosystem/mindspore
refresh resize interface
This commit is contained in:
parent
be62fd7fa6
commit
f2addd3f6e
|
@ -116,7 +116,7 @@ class MS_API LiteSession {
|
||||||
/// \param[in] inputs Define the new inputs shape.
|
/// \param[in] inputs Define the new inputs shape.
|
||||||
///
|
///
|
||||||
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
|
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
|
||||||
virtual int Resize(const std::vector<tensor::MSTensor *> &inputs) = 0;
|
virtual int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>>& dims) = 0;
|
||||||
};
|
};
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -389,34 +389,51 @@ std::unordered_map<std::string, mindspore::tensor::MSTensor *> LiteSession::GetO
|
||||||
return this->output_tensor_map_;
|
return this->output_tensor_map_;
|
||||||
}
|
}
|
||||||
|
|
||||||
int LiteSession::ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs) {
|
int LiteSession::ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs,
|
||||||
|
const std::vector<std::vector<int>> &dims) {
|
||||||
if (inputs.size() != inputs_.size()) {
|
if (inputs.size() != inputs_.size()) {
|
||||||
MS_LOG(ERROR) << "Inputs size " << inputs.size() << " is not equal to " << inputs_.size();
|
MS_LOG(ERROR) << "Inputs size " << inputs.size() << " is not equal to " << inputs_.size();
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (dims.size() != inputs.size()) {
|
||||||
|
MS_LOG(ERROR) << "Input dims size " << dims.size() << " is not equal to the inputs size " << inputs.size();
|
||||||
|
return RET_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
if (inputs[i] == nullptr) {
|
if (inputs[i] != inputs_[i]) {
|
||||||
MS_LOG(ERROR) << "Input tensor is nullptr!";
|
MS_LOG(ERROR) << "Input[" << i << "] tensor is not equal to the inputs have been saved!";
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
inputs_[i]->set_shape(inputs[i]->shape());
|
|
||||||
|
inputs_[i]->set_shape(dims[i]);
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs) {
|
void LiteSession::ResetInputsShape(const std::vector<std::vector<int>> &dims) {
|
||||||
std::vector<Tensor *> inputs_old(inputs_);
|
for (size_t i = 0; i < inputs_.size(); ++i) {
|
||||||
auto ret = ResizeInputs(inputs);
|
inputs_[i]->set_shape(dims[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
|
||||||
|
const std::vector<std::vector<int>> &dims) {
|
||||||
|
std::vector<std::vector<int>> old_dims;
|
||||||
|
for (size_t i = 0; i < inputs_.size(); ++i) {
|
||||||
|
old_dims.push_back(inputs_[i]->shape());
|
||||||
|
}
|
||||||
|
auto ret = ResizeInputs(inputs, dims);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
inputs_ = inputs_old;
|
ResetInputsShape(old_dims);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
Scheduler scheduler(context_);
|
Scheduler scheduler(context_);
|
||||||
ret = scheduler.ReSizeKernels(kernels_);
|
ret = scheduler.ReSizeKernels(kernels_);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
inputs_ = inputs_old;
|
ResetInputsShape(old_dims);
|
||||||
auto resize_ret = scheduler.ReSizeKernels(kernels_);
|
auto resize_ret = scheduler.ReSizeKernels(kernels_);
|
||||||
if (resize_ret != RET_OK) {
|
if (resize_ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "restore kernel size fail!ret: " << resize_ret;
|
MS_LOG(ERROR) << "restore kernel size fail!ret: " << resize_ret;
|
||||||
|
|
|
@ -59,7 +59,8 @@ class LiteSession : public session::LiteSession {
|
||||||
|
|
||||||
std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override;
|
std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputs() const override;
|
||||||
|
|
||||||
int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs) override;
|
int Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
|
||||||
|
const std::vector<std::vector<int>> &dims) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
int ConvertTensors(const lite::Model *model);
|
int ConvertTensors(const lite::Model *model);
|
||||||
|
@ -80,7 +81,11 @@ class LiteSession : public session::LiteSession {
|
||||||
|
|
||||||
void InitGraphOutputTensorMap(const lite::Model *model);
|
void InitGraphOutputTensorMap(const lite::Model *model);
|
||||||
|
|
||||||
int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs);
|
int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs,
|
||||||
|
const std::vector<std::vector<int>> &dims);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void ResetInputsShape(const std::vector<std::vector<int>> &dims);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Context *context_ = nullptr;
|
Context *context_ = nullptr;
|
||||||
|
|
|
@ -52,6 +52,7 @@ int Scheduler::Schedule(const lite::Model *model, std::vector<Tensor *> *tensors
|
||||||
}
|
}
|
||||||
|
|
||||||
int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
|
int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
|
||||||
|
bool infer_shape_interrupt = false;
|
||||||
for (size_t i = 0; i < kernels.size(); ++i) {
|
for (size_t i = 0; i < kernels.size(); ++i) {
|
||||||
if (kernels[i] == nullptr) {
|
if (kernels[i] == nullptr) {
|
||||||
MS_LOG(ERROR) << "input kernel is nullptr!";
|
MS_LOG(ERROR) << "input kernel is nullptr!";
|
||||||
|
@ -64,15 +65,25 @@ int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
|
||||||
}
|
}
|
||||||
std::vector<Tensor *> &inputs = kernels[i]->in_tensors();
|
std::vector<Tensor *> &inputs = kernels[i]->in_tensors();
|
||||||
std::vector<Tensor *> &outputs = kernels[i]->out_tensors();
|
std::vector<Tensor *> &outputs = kernels[i]->out_tensors();
|
||||||
|
primitive->SetInferFlag(!infer_shape_interrupt);
|
||||||
auto ret = primitive->InferShape(inputs, outputs);
|
auto ret = primitive->InferShape(inputs, outputs);
|
||||||
if (ret != RET_OK) {
|
if (ret == RET_INFER_INVALID) {
|
||||||
MS_LOG(ERROR) << "InferShape failed, name: " << kernels[i]->name() << ", ret = " << ret;
|
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, type:"
|
||||||
return ret;
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()))
|
||||||
|
<< "flag set to false.";
|
||||||
|
primitive->SetInferFlag(false);
|
||||||
|
infer_shape_interrupt = true;
|
||||||
|
} else if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "InferShape failed, type: "
|
||||||
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type()));
|
||||||
|
return RET_INFER_ERR;
|
||||||
}
|
}
|
||||||
ret = kernels[i]->ReSize();
|
if (!infer_shape_interrupt) {
|
||||||
if (ret != RET_OK) {
|
ret = kernels[i]->ReSize();
|
||||||
MS_LOG(ERROR) << "kernel " << kernels[i]->name() << " resize fail!ret = " << ret;
|
if (ret != RET_OK) {
|
||||||
return ret;
|
MS_LOG(ERROR) << "kernel " << kernels[i]->name() << " resize fail!ret = " << ret;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
|
Loading…
Reference in New Issue