!7864 fix bug in func SubgraphInputTensors and SubgraphOutputTensors

Merge pull request !7864 from hangq/primitive
This commit is contained in:
mindspore-ci-bot 2020-10-27 23:16:20 +08:00 committed by Gitee
commit 7bc5395e23
2 changed files with 46 additions and 24 deletions

View File

@ -181,17 +181,30 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputKernels(
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vector<kernel::LiteKernel *> &kernels) { std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
std::vector<lite::Tensor *> input_tensors; std::vector<lite::Tensor *> input_tensors;
std::vector<lite::Tensor *> all_output_tensors;
for (const auto &kernel : kernels) {
auto kernel_out_tensors = kernel->out_tensors();
all_output_tensors.insert(all_output_tensors.end(), kernel_out_tensors.begin(), kernel_out_tensors.end());
}
std::vector<kernel::LiteKernel *> input_kernels = SubgraphInputKernels(kernels); std::vector<kernel::LiteKernel *> input_kernels = SubgraphInputKernels(kernels);
for (const auto &kernel : input_kernels) { for (const auto &input_kernel : input_kernels) {
for (const auto &tensor : kernel->in_tensors()) { auto &outer_in_kernels = input_kernel->in_kernels();
auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor); auto &in_kernel_in_tensors = input_kernel->in_tensors();
if (iter == all_output_tensors.end() && !tensor->IsConst()) { if (outer_in_kernels.empty()) {
input_tensors.emplace_back(tensor); for (auto &in_kernel_in_tensor : in_kernel_in_tensors) {
if (!in_kernel_in_tensor->IsConst()) {
input_tensors.push_back(in_kernel_in_tensor);
}
}
continue;
}
for (auto outer_in_kernel : outer_in_kernels) {
auto iter = std::find(kernels.begin(), kernels.end(), outer_in_kernel);
if (iter != kernels.end()) {
continue;
}
auto &outer_in_kernel_out_tensors = outer_in_kernel->out_tensors();
for (auto in_kernel_in_tensor : in_kernel_in_tensors) {
auto outer_in_kernel_out_tensors_iter =
std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor);
if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
input_tensors.emplace_back(in_kernel_in_tensor);
}
} }
} }
} }
@ -200,17 +213,26 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vector<kernel::LiteKernel *> &kernels) { std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
std::vector<lite::Tensor *> output_tensors; std::vector<lite::Tensor *> output_tensors;
std::vector<lite::Tensor *> all_input_tensors;
for (const auto &kernel : kernels) {
auto kernel_in_tensors = kernel->in_tensors();
all_input_tensors.insert(all_input_tensors.end(), kernel_in_tensors.begin(), kernel_in_tensors.end());
}
std::vector<kernel::LiteKernel *> output_kernels = SubgraphOutputKernels(kernels); std::vector<kernel::LiteKernel *> output_kernels = SubgraphOutputKernels(kernels);
for (const auto &kernel : output_kernels) { for (const auto &output_kernel : output_kernels) {
for (const auto &tensor : kernel->out_tensors()) { auto &outer_out_kernels = output_kernel->out_kernels();
auto iter = std::find(all_input_tensors.begin(), all_input_tensors.end(), tensor); auto &out_kernel_out_tensors = output_kernel->out_tensors();
if (iter == all_input_tensors.end()) { if (outer_out_kernels.empty()) {
output_tensors.emplace_back(tensor); output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end());
continue;
}
for (auto outer_out_kernel : outer_out_kernels) {
auto iter = std::find(kernels.begin(), kernels.end(), outer_out_kernel);
if (iter != kernels.end()) {
continue;
}
auto &outer_out_kernel_in_tensors = outer_out_kernel->in_tensors();
for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
auto outer_out_kernel_in_tensors_iter =
std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
output_tensors.emplace_back(out_kernel_out_tensor);
}
} }
} }
} }

View File

@ -120,9 +120,9 @@ class LiteKernel {
void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) { this->out_tensors_ = out_tensors; } void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) { this->out_tensors_ = out_tensors; }
std::vector<lite::Tensor *> in_tensors() const { return this->in_tensors_; } const std::vector<lite::Tensor *> &in_tensors() const { return this->in_tensors_; }
std::vector<lite::Tensor *> out_tensors() const { return this->out_tensors_; } const std::vector<lite::Tensor *> &out_tensors() const { return this->out_tensors_; }
void AddInKernel(LiteKernel *kernel) { void AddInKernel(LiteKernel *kernel) {
if (!lite::IsContain(this->in_kernels_, kernel)) { if (!lite::IsContain(this->in_kernels_, kernel)) {
@ -140,9 +140,9 @@ class LiteKernel {
void SetOutKernel(const std::vector<LiteKernel *> &kernel) { this->out_kernels_ = kernel; } void SetOutKernel(const std::vector<LiteKernel *> &kernel) { this->out_kernels_ = kernel; }
std::vector<LiteKernel *> in_kernels() const { return this->in_kernels_; } const std::vector<LiteKernel *> &in_kernels() const { return this->in_kernels_; }
std::vector<LiteKernel *> out_kernels() const { return this->out_kernels_; } const std::vector<LiteKernel *> &out_kernels() const { return this->out_kernels_; }
void InitOutTensorRefCount(); void InitOutTensorRefCount();