!19108 [MSLITE][Develop] add all tensors vector for delegate model

Merge pull request !19108 from yangruoqi713/delegate
This commit is contained in:
i-robot 2021-07-06 13:16:35 +00:00 committed by Gitee
commit e500e9e750
4 changed files with 30 additions and 8 deletions

View File

@ -29,9 +29,10 @@ using KernelIter = std::vector<kernel::Kernel *>::iterator;
class DelegateModel {
public:
/// \brief Constructor of MindSpore Lite DelegateModel.
DelegateModel(std::vector<kernel::Kernel *> *kernels,
const std::map<kernel::Kernel *, const schema::Primitive *> primitives)
: kernels_(kernels), primitives_(primitives) {}
DelegateModel(std::vector<kernel::Kernel *> *kernels, const std::vector<tensor::MSTensor *> &inputs,
const std::vector<tensor::MSTensor *> &outputs,
const std::map<kernel::Kernel *, const schema::Primitive *> &primitives)
: kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives) {}
/// \brief Destructor of MindSpore Lite DelegateModel.
~DelegateModel() = default;
@ -61,9 +62,15 @@ class DelegateModel {
/// \return The next iterator after graph_kernel, point to the next kernel that is not visited.
KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel);
const std::vector<mindspore::tensor::MSTensor *> &inputs() { return this->inputs_; }
const std::vector<mindspore::tensor::MSTensor *> &outputs() { return this->outputs_; }
protected:
std::vector<kernel::Kernel *> *kernels_;
const std::map<kernel::Kernel *, const schema::Primitive *> primitives_;
const std::vector<mindspore::tensor::MSTensor *> &inputs_;
const std::vector<mindspore::tensor::MSTensor *> &outputs_;
const std::map<kernel::Kernel *, const schema::Primitive *> &primitives_;
};
typedef void (*DelegateHook)(std::shared_ptr<Delegate> delegate);

View File

@ -499,7 +499,7 @@ int LiteSession::CompileGraph(Model *model) {
return ret;
}
// scheduler kernels
Scheduler scheduler(context_, model, &tensors_, is_train_session_, delegate_);
Scheduler scheduler(context_, model, &tensors_, inputs_, outputs_, is_train_session_, delegate_);
scheduler.SetupSchedulerCb(std::move(sched_cb_));
ret = scheduler.Schedule(&kernels_);
if (ret != RET_OK) {

View File

@ -161,7 +161,17 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
for (size_t i = 0; i < dst_kernels->size(); i++) {
kernels.push_back((*dst_kernels)[i]->kernel());
}
DelegateModel *model = new (std::nothrow) DelegateModel(&kernels, primitives_);
std::vector<tensor::MSTensor *> input_ms_tensors;
input_ms_tensors.resize(inputs_.size());
(void)std::transform(inputs_.begin(), inputs_.end(), input_ms_tensors.begin(),
[](lite::Tensor *tensor) { return reinterpret_cast<tensor::MSTensor *>(tensor); });
std::vector<tensor::MSTensor *> output_ms_tensors;
output_ms_tensors.resize(outputs_.size());
(void)std::transform(outputs_.begin(), outputs_.end(), output_ms_tensors.begin(),
[](lite::Tensor *tensor) { return reinterpret_cast<tensor::MSTensor *>(tensor); });
DelegateModel *model = new (std::nothrow) DelegateModel(&kernels, input_ms_tensors, output_ms_tensors, primitives_);
if (model == nullptr) {
MS_LOG(ERROR) << "New delegate model failed.";
return RET_NULL_PTR;

View File

@ -34,11 +34,14 @@
namespace mindspore::lite {
class Scheduler {
public:
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
std::shared_ptr<Delegate> delegate = nullptr)
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors,
bool is_train_session, std::shared_ptr<Delegate> delegate = nullptr)
: context_(ctx),
src_model_(src_model),
src_tensors_(src_tensors),
inputs_(input_tensors),
outputs_(output_tensors),
is_train_session_(is_train_session),
delegate_(delegate) {}
~Scheduler() = default;
@ -111,6 +114,8 @@ class Scheduler {
const InnerContext *context_ = nullptr;
Model *src_model_ = nullptr;
std::vector<Tensor *> *src_tensors_;
const std::vector<Tensor *> &inputs_;
const std::vector<Tensor *> &outputs_;
std::vector<size_t> graph_output_node_indexes_;
std::map<int, OpParameter *> op_parameters_;
bool is_train_session_ = false;