!19108 [MSLITE][Develop] add all tensors vector for delegate model
Merge pull request !19108 from yangruoqi713/delegate
This commit is contained in:
commit
e500e9e750
|
@ -29,9 +29,10 @@ using KernelIter = std::vector<kernel::Kernel *>::iterator;
|
|||
class DelegateModel {
|
||||
public:
|
||||
/// \brief Constructor of MindSpore Lite DelegateModel.
|
||||
DelegateModel(std::vector<kernel::Kernel *> *kernels,
|
||||
const std::map<kernel::Kernel *, const schema::Primitive *> primitives)
|
||||
: kernels_(kernels), primitives_(primitives) {}
|
||||
DelegateModel(std::vector<kernel::Kernel *> *kernels, const std::vector<tensor::MSTensor *> &inputs,
|
||||
const std::vector<tensor::MSTensor *> &outputs,
|
||||
const std::map<kernel::Kernel *, const schema::Primitive *> &primitives)
|
||||
: kernels_(kernels), inputs_(inputs), outputs_(outputs), primitives_(primitives) {}
|
||||
|
||||
/// \brief Destructor of MindSpore Lite DelegateModel.
|
||||
~DelegateModel() = default;
|
||||
|
@ -61,9 +62,15 @@ class DelegateModel {
|
|||
/// \return The next iterator after graph_kernel, point to the next kernel that is not visited.
|
||||
KernelIter Replace(KernelIter from, KernelIter end, kernel::Kernel *graph_kernel);
|
||||
|
||||
const std::vector<mindspore::tensor::MSTensor *> &inputs() { return this->inputs_; }
|
||||
|
||||
const std::vector<mindspore::tensor::MSTensor *> &outputs() { return this->outputs_; }
|
||||
|
||||
protected:
|
||||
std::vector<kernel::Kernel *> *kernels_;
|
||||
const std::map<kernel::Kernel *, const schema::Primitive *> primitives_;
|
||||
const std::vector<mindspore::tensor::MSTensor *> &inputs_;
|
||||
const std::vector<mindspore::tensor::MSTensor *> &outputs_;
|
||||
const std::map<kernel::Kernel *, const schema::Primitive *> &primitives_;
|
||||
};
|
||||
|
||||
typedef void (*DelegateHook)(std::shared_ptr<Delegate> delegate);
|
||||
|
|
|
@ -499,7 +499,7 @@ int LiteSession::CompileGraph(Model *model) {
|
|||
return ret;
|
||||
}
|
||||
// scheduler kernels
|
||||
Scheduler scheduler(context_, model, &tensors_, is_train_session_, delegate_);
|
||||
Scheduler scheduler(context_, model, &tensors_, inputs_, outputs_, is_train_session_, delegate_);
|
||||
scheduler.SetupSchedulerCb(std::move(sched_cb_));
|
||||
ret = scheduler.Schedule(&kernels_);
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -161,7 +161,17 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
|
|||
for (size_t i = 0; i < dst_kernels->size(); i++) {
|
||||
kernels.push_back((*dst_kernels)[i]->kernel());
|
||||
}
|
||||
DelegateModel *model = new (std::nothrow) DelegateModel(&kernels, primitives_);
|
||||
|
||||
std::vector<tensor::MSTensor *> input_ms_tensors;
|
||||
input_ms_tensors.resize(inputs_.size());
|
||||
(void)std::transform(inputs_.begin(), inputs_.end(), input_ms_tensors.begin(),
|
||||
[](lite::Tensor *tensor) { return reinterpret_cast<tensor::MSTensor *>(tensor); });
|
||||
std::vector<tensor::MSTensor *> output_ms_tensors;
|
||||
output_ms_tensors.resize(outputs_.size());
|
||||
(void)std::transform(outputs_.begin(), outputs_.end(), output_ms_tensors.begin(),
|
||||
[](lite::Tensor *tensor) { return reinterpret_cast<tensor::MSTensor *>(tensor); });
|
||||
|
||||
DelegateModel *model = new (std::nothrow) DelegateModel(&kernels, input_ms_tensors, output_ms_tensors, primitives_);
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "New delegate model failed.";
|
||||
return RET_NULL_PTR;
|
||||
|
|
|
@ -34,11 +34,14 @@
|
|||
namespace mindspore::lite {
|
||||
class Scheduler {
|
||||
public:
|
||||
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors, bool is_train_session,
|
||||
std::shared_ptr<Delegate> delegate = nullptr)
|
||||
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors,
|
||||
const std::vector<Tensor *> &input_tensors, const std::vector<Tensor *> &output_tensors,
|
||||
bool is_train_session, std::shared_ptr<Delegate> delegate = nullptr)
|
||||
: context_(ctx),
|
||||
src_model_(src_model),
|
||||
src_tensors_(src_tensors),
|
||||
inputs_(input_tensors),
|
||||
outputs_(output_tensors),
|
||||
is_train_session_(is_train_session),
|
||||
delegate_(delegate) {}
|
||||
~Scheduler() = default;
|
||||
|
@ -111,6 +114,8 @@ class Scheduler {
|
|||
const InnerContext *context_ = nullptr;
|
||||
Model *src_model_ = nullptr;
|
||||
std::vector<Tensor *> *src_tensors_;
|
||||
const std::vector<Tensor *> &inputs_;
|
||||
const std::vector<Tensor *> &outputs_;
|
||||
std::vector<size_t> graph_output_node_indexes_;
|
||||
std::map<int, OpParameter *> op_parameters_;
|
||||
bool is_train_session_ = false;
|
||||
|
|
Loading…
Reference in New Issue