diff --git a/mindspore/lite/include/delegate.h b/mindspore/lite/include/delegate.h index c19f94a4d2f..b841e722d3b 100644 --- a/mindspore/lite/include/delegate.h +++ b/mindspore/lite/include/delegate.h @@ -29,9 +29,10 @@ using KernelIter = std::vector::iterator; class DelegateModel { public: /// \brief Constructor of MindSpore Lite DelegateModel. - DelegateModel(std::vector *kernels, - const std::map primitives) - : kernels_(kernels), primitives_(primitives) {} + DelegateModel(std::vector *kernels, const std::vector &inputs, + const std::vector &outputs, + const std::map &primitives) + : kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives) {} /// \brief Destructor of MindSpore Lite DelegateModel. ~DelegateModel() = default; @@ -61,9 +62,15 @@ class DelegateModel { /// \return The next iterator after graph_kernel, point to the next kernel that is not visited. KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel); + const std::vector &inputs() { return this->inputs_; } + + const std::vector &outputs() { return this->outputs_; } + protected: std::vector *kernels_; - const std::map primitives_; + const std::vector &inputs_; + const std::vector &outputs_; + const std::map &primitives_; }; typedef void (*DelegateHook)(std::shared_ptr delegate); diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 4911a819b12..5c713fb929b 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -499,7 +499,7 @@ int LiteSession::CompileGraph(Model *model) { return ret; } // scheduler kernels - Scheduler scheduler(context_, model, &tensors_, is_train_session_, delegate_); + Scheduler scheduler(context_, model, &tensors_, inputs_, outputs_, is_train_session_, delegate_); scheduler.SetupSchedulerCb(std::move(sched_cb_)); ret = scheduler.Schedule(&kernels_); if (ret != RET_OK) { diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 84e30a09eb2..556f7e43246 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -161,7 +161,17 @@ int Scheduler::ReplaceDelegateKernels(std::vector *dst_ker for (size_t i = 0; i < dst_kernels->size(); i++) { kernels.push_back((*dst_kernels)[i]->kernel()); } - DelegateModel *model = new (std::nothrow) DelegateModel(&kernels, primitives_); + + std::vector input_ms_tensors; + input_ms_tensors.resize(inputs_.size()); + (void)std::transform(inputs_.begin(), inputs_.end(), input_ms_tensors.begin(), + [](lite::Tensor *tensor) { return reinterpret_cast(tensor); }); + std::vector output_ms_tensors; + output_ms_tensors.resize(outputs_.size()); + (void)std::transform(outputs_.begin(), outputs_.end(), output_ms_tensors.begin(), + [](lite::Tensor *tensor) { return reinterpret_cast(tensor); }); + + DelegateModel *model = new (std::nothrow) DelegateModel(&kernels, input_ms_tensors, output_ms_tensors, primitives_); if (model == nullptr) { MS_LOG(ERROR) << "New delegate model failed."; return RET_NULL_PTR; diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index e4f0bc606b1..a3c91ffaaad 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -34,11 +34,14 @@ namespace mindspore::lite { class Scheduler { public: - Scheduler(const InnerContext *ctx, Model *src_model, std::vector *src_tensors, bool is_train_session, - std::shared_ptr delegate = nullptr) + Scheduler(const InnerContext *ctx, Model *src_model, std::vector *src_tensors, + const std::vector &input_tensors, const std::vector &output_tensors, + bool is_train_session, std::shared_ptr delegate = nullptr) : context_(ctx), src_model_(src_model), src_tensors_(src_tensors), + inputs_(input_tensors), + outputs_(output_tensors), is_train_session_(is_train_session), delegate_(delegate) {} ~Scheduler() = default; @@ -111,6 +114,8 @@ class Scheduler { const InnerContext *context_ = nullptr; Model *src_model_ = nullptr; std::vector *src_tensors_; + const std::vector &inputs_; + const std::vector &outputs_; std::vector graph_output_node_indexes_; std::map op_parameters_; bool is_train_session_ = false;