forked from mindspore-Ecosystem/mindspore
!7864 fix bug in func SubgraphInputTensors and SubgraphOutputTensors
Merge pull request !7864 from hangq/primitive
This commit is contained in:
commit
7bc5395e23
|
@ -181,17 +181,30 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputKernels(
|
|||
|
||||
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
|
||||
std::vector<lite::Tensor *> input_tensors;
|
||||
std::vector<lite::Tensor *> all_output_tensors;
|
||||
for (const auto &kernel : kernels) {
|
||||
auto kernel_out_tensors = kernel->out_tensors();
|
||||
all_output_tensors.insert(all_output_tensors.end(), kernel_out_tensors.begin(), kernel_out_tensors.end());
|
||||
}
|
||||
std::vector<kernel::LiteKernel *> input_kernels = SubgraphInputKernels(kernels);
|
||||
for (const auto &kernel : input_kernels) {
|
||||
for (const auto &tensor : kernel->in_tensors()) {
|
||||
auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor);
|
||||
if (iter == all_output_tensors.end() && !tensor->IsConst()) {
|
||||
input_tensors.emplace_back(tensor);
|
||||
for (const auto &input_kernel : input_kernels) {
|
||||
auto &outer_in_kernels = input_kernel->in_kernels();
|
||||
auto &in_kernel_in_tensors = input_kernel->in_tensors();
|
||||
if (outer_in_kernels.empty()) {
|
||||
for (auto &in_kernel_in_tensor : in_kernel_in_tensors) {
|
||||
if (!in_kernel_in_tensor->IsConst()) {
|
||||
input_tensors.push_back(in_kernel_in_tensor);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
for (auto outer_in_kernel : outer_in_kernels) {
|
||||
auto iter = std::find(kernels.begin(), kernels.end(), outer_in_kernel);
|
||||
if (iter != kernels.end()) {
|
||||
continue;
|
||||
}
|
||||
auto &outer_in_kernel_out_tensors = outer_in_kernel->out_tensors();
|
||||
for (auto in_kernel_in_tensor : in_kernel_in_tensors) {
|
||||
auto outer_in_kernel_out_tensors_iter =
|
||||
std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor);
|
||||
if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
|
||||
input_tensors.emplace_back(in_kernel_in_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -200,17 +213,26 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
|
|||
|
||||
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vector<kernel::LiteKernel *> &kernels) {
|
||||
std::vector<lite::Tensor *> output_tensors;
|
||||
std::vector<lite::Tensor *> all_input_tensors;
|
||||
for (const auto &kernel : kernels) {
|
||||
auto kernel_in_tensors = kernel->in_tensors();
|
||||
all_input_tensors.insert(all_input_tensors.end(), kernel_in_tensors.begin(), kernel_in_tensors.end());
|
||||
}
|
||||
std::vector<kernel::LiteKernel *> output_kernels = SubgraphOutputKernels(kernels);
|
||||
for (const auto &kernel : output_kernels) {
|
||||
for (const auto &tensor : kernel->out_tensors()) {
|
||||
auto iter = std::find(all_input_tensors.begin(), all_input_tensors.end(), tensor);
|
||||
if (iter == all_input_tensors.end()) {
|
||||
output_tensors.emplace_back(tensor);
|
||||
for (const auto &output_kernel : output_kernels) {
|
||||
auto &outer_out_kernels = output_kernel->out_kernels();
|
||||
auto &out_kernel_out_tensors = output_kernel->out_tensors();
|
||||
if (outer_out_kernels.empty()) {
|
||||
output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end());
|
||||
continue;
|
||||
}
|
||||
for (auto outer_out_kernel : outer_out_kernels) {
|
||||
auto iter = std::find(kernels.begin(), kernels.end(), outer_out_kernel);
|
||||
if (iter != kernels.end()) {
|
||||
continue;
|
||||
}
|
||||
auto &outer_out_kernel_in_tensors = outer_out_kernel->in_tensors();
|
||||
for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
|
||||
auto outer_out_kernel_in_tensors_iter =
|
||||
std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
|
||||
if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
|
||||
output_tensors.emplace_back(out_kernel_out_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -120,9 +120,9 @@ class LiteKernel {
|
|||
|
||||
void set_out_tensors(const std::vector<lite::Tensor *> &out_tensors) { this->out_tensors_ = out_tensors; }
|
||||
|
||||
std::vector<lite::Tensor *> in_tensors() const { return this->in_tensors_; }
|
||||
const std::vector<lite::Tensor *> &in_tensors() const { return this->in_tensors_; }
|
||||
|
||||
std::vector<lite::Tensor *> out_tensors() const { return this->out_tensors_; }
|
||||
const std::vector<lite::Tensor *> &out_tensors() const { return this->out_tensors_; }
|
||||
|
||||
void AddInKernel(LiteKernel *kernel) {
|
||||
if (!lite::IsContain(this->in_kernels_, kernel)) {
|
||||
|
@ -140,9 +140,9 @@ class LiteKernel {
|
|||
|
||||
void SetOutKernel(const std::vector<LiteKernel *> &kernel) { this->out_kernels_ = kernel; }
|
||||
|
||||
std::vector<LiteKernel *> in_kernels() const { return this->in_kernels_; }
|
||||
const std::vector<LiteKernel *> &in_kernels() const { return this->in_kernels_; }
|
||||
|
||||
std::vector<LiteKernel *> out_kernels() const { return this->out_kernels_; }
|
||||
const std::vector<LiteKernel *> &out_kernels() const { return this->out_kernels_; }
|
||||
|
||||
void InitOutTensorRefCount();
|
||||
|
||||
|
|
Loading…
Reference in New Issue