diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index dbff82263b1..997878b5988 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -181,17 +181,30 @@ std::vector LiteKernelUtil::SubgraphOutputKernels( std::vector LiteKernelUtil::SubgraphInputTensors(const std::vector &kernels) { std::vector input_tensors; - std::vector all_output_tensors; - for (const auto &kernel : kernels) { - auto kernel_out_tensors = kernel->out_tensors(); - all_output_tensors.insert(all_output_tensors.end(), kernel_out_tensors.begin(), kernel_out_tensors.end()); - } std::vector input_kernels = SubgraphInputKernels(kernels); - for (const auto &kernel : input_kernels) { - for (const auto &tensor : kernel->in_tensors()) { - auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor); - if (iter == all_output_tensors.end() && !tensor->IsConst()) { - input_tensors.emplace_back(tensor); + for (const auto &input_kernel : input_kernels) { + auto &outer_in_kernels = input_kernel->in_kernels(); + auto &in_kernel_in_tensors = input_kernel->in_tensors(); + if (outer_in_kernels.empty()) { + for (auto &in_kernel_in_tensor : in_kernel_in_tensors) { + if (!in_kernel_in_tensor->IsConst()) { + input_tensors.push_back(in_kernel_in_tensor); + } + } + continue; + } + for (auto outer_in_kernel : outer_in_kernels) { + auto iter = std::find(kernels.begin(), kernels.end(), outer_in_kernel); + if (iter != kernels.end()) { + continue; + } + auto &outer_in_kernel_out_tensors = outer_in_kernel->out_tensors(); + for (auto in_kernel_in_tensor : in_kernel_in_tensors) { + auto outer_in_kernel_out_tensors_iter = + std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor); + if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { + input_tensors.emplace_back(in_kernel_in_tensor); + } } } } @@ -200,17 +213,26 @@ std::vector LiteKernelUtil::SubgraphInputTensors(const std::vect std::vector LiteKernelUtil::SubgraphOutputTensors(const std::vector &kernels) { std::vector output_tensors; - std::vector all_input_tensors; - for (const auto &kernel : kernels) { - auto kernel_in_tensors = kernel->in_tensors(); - all_input_tensors.insert(all_input_tensors.end(), kernel_in_tensors.begin(), kernel_in_tensors.end()); - } std::vector output_kernels = SubgraphOutputKernels(kernels); - for (const auto &kernel : output_kernels) { - for (const auto &tensor : kernel->out_tensors()) { - auto iter = std::find(all_input_tensors.begin(), all_input_tensors.end(), tensor); - if (iter == all_input_tensors.end()) { - output_tensors.emplace_back(tensor); + for (const auto &output_kernel : output_kernels) { + auto &outer_out_kernels = output_kernel->out_kernels(); + auto &out_kernel_out_tensors = output_kernel->out_tensors(); + if (outer_out_kernels.empty()) { + output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end()); + continue; + } + for (auto outer_out_kernel : outer_out_kernels) { + auto iter = std::find(kernels.begin(), kernels.end(), outer_out_kernel); + if (iter != kernels.end()) { + continue; + } + auto &outer_out_kernel_in_tensors = outer_out_kernel->in_tensors(); + for (auto out_kernel_out_tensor : out_kernel_out_tensors) { + auto outer_out_kernel_in_tensors_iter = + std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); + if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { + output_tensors.emplace_back(out_kernel_out_tensor); + } } } } diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index fb05aadc441..64571c263d2 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -120,9 +120,9 @@ class LiteKernel { void set_out_tensors(const std::vector &out_tensors) { this->out_tensors_ = out_tensors; } - std::vector in_tensors() const { return this->in_tensors_; } + const std::vector &in_tensors() const { return this->in_tensors_; } - std::vector out_tensors() const { return this->out_tensors_; } + const std::vector &out_tensors() const { return this->out_tensors_; } void AddInKernel(LiteKernel *kernel) { if (!lite::IsContain(this->in_kernels_, kernel)) { @@ -140,9 +140,9 @@ class LiteKernel { void SetOutKernel(const std::vector &kernel) { this->out_kernels_ = kernel; } - std::vector in_kernels() const { return this->in_kernels_; } + const std::vector &in_kernels() const { return this->in_kernels_; } - std::vector out_kernels() const { return this->out_kernels_; } + const std::vector &out_kernels() const { return this->out_kernels_; } void InitOutTensorRefCount();