forked from mindspore-Ecosystem/mindspore
fix subgraph inputs and add ut
This commit is contained in:
parent
6c4b4f91d2
commit
1b6e5facd4
|
@ -42,11 +42,16 @@ int LiteKernel::DecOutTensorRefCount() {
|
|||
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
|
||||
const std::vector<kernel::LiteKernel *> &kernels) {
|
||||
std::vector<kernel::LiteKernel *> input_kernels;
|
||||
for (const auto kernel : kernels) {
|
||||
for (auto input : kernel->in_kernels()) {
|
||||
for (const auto &kernel : kernels) {
|
||||
if (kernel->in_kernels().empty() && !kernel->in_tensors().empty()) {
|
||||
input_kernels.emplace_back(kernel);
|
||||
continue;
|
||||
}
|
||||
for (const auto &input : kernel->in_kernels()) {
|
||||
auto iter = std::find(kernels.begin(), kernels.end(), input);
|
||||
if (iter == kernels.end()) {
|
||||
input_kernels.emplace_back(input);
|
||||
auto item = std::find(input_kernels.begin(), input_kernels.end(), kernel);
|
||||
if (iter == kernels.end() && item == input_kernels.end()) {
|
||||
input_kernels.emplace_back(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -56,11 +61,16 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
|
|||
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputKernels(
|
||||
const std::vector<kernel::LiteKernel *> &kernels) {
|
||||
std::vector<kernel::LiteKernel *> output_kernels;
|
||||
for (const auto kernel : kernels) {
|
||||
for (const auto output : kernel->out_kernels()) {
|
||||
for (const auto &kernel : kernels) {
|
||||
if (kernel->out_kernels().empty() && !kernel->out_tensors().empty()) {
|
||||
output_kernels.emplace_back(kernel);
|
||||
continue;
|
||||
}
|
||||
for (const auto &output : kernel->out_kernels()) {
|
||||
auto iter = std::find(kernels.begin(), kernels.end(), output);
|
||||
if (iter == kernels.end()) {
|
||||
output_kernels.emplace_back(output);
|
||||
auto item = std::find(output_kernels.begin(), output_kernels.end(), kernel);
|
||||
if (iter == kernels.end() && item == output_kernels.end()) {
|
||||
output_kernels.emplace_back(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -289,6 +289,7 @@ set(TEST_SRC
|
|||
${TEST_DIR}/main.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/common/pack_tests.cc
|
||||
${TEST_DIR}/ut/src/infer_test.cc
|
||||
${TEST_DIR}/ut/src/utils_test.cc
|
||||
)
|
||||
|
||||
if (SUPPORT_TRAIN)
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include "mindspore/lite/schema/inner/model_generated.h"
|
||||
#include "mindspore/lite/include/model.h"
|
||||
#include "common/common_test.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "mindspore/lite/src/lite_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
// Test fixture for the LiteKernelUtil subgraph-boundary helpers.
class UtilsTest : public mindspore::CommonTest {
 public:
  UtilsTest() = default;
};
|
||||
|
||||
// Builds a three-kernel chain and checks that the subgraph helpers find
// its boundary kernels and tensors.
//
// Topology (tensors in brackets):
//   [t0, t1] -> head -> [t2] -> body -> [t3] -> tail -> [t4]
TEST_F(UtilsTest, TestSubgraph) {
  auto head = std::make_shared<kernel::LiteKernel>();
  auto body = std::make_shared<kernel::LiteKernel>();
  auto tail = std::make_shared<kernel::LiteKernel>();

  auto graph_in_a = std::make_shared<lite::tensor::Tensor>();
  auto graph_in_b = std::make_shared<lite::tensor::Tensor>();
  auto link_head_body = std::make_shared<lite::tensor::Tensor>();
  auto link_body_tail = std::make_shared<lite::tensor::Tensor>();
  auto graph_out = std::make_shared<lite::tensor::Tensor>();

  // Wire the kernel-to-kernel edges in both directions.
  head->AddOutKernel(body.get());
  body->AddInKernel(head.get());
  body->AddOutKernel(tail.get());
  tail->AddInKernel(body.get());

  // Attach tensors: two graph inputs feed `head`; each link carries one
  // tensor; `tail` produces the single graph output.
  head->set_in_tensors({graph_in_a.get(), graph_in_b.get()});
  head->set_out_tensors({link_head_body.get()});
  body->set_in_tensors({link_head_body.get()});
  body->set_out_tensors({link_body_tail.get()});
  tail->set_in_tensors({link_body_tail.get()});
  tail->set_out_tensors({graph_out.get()});

  std::vector<kernel::LiteKernel *> kernels = {head.get(), body.get(), tail.get()};

  // Only `head` has no in-kernels, and only `tail` has no out-kernels,
  // so exactly one input kernel and one output kernel are expected.
  auto input_kernels = kernel::LiteKernelUtil::SubgraphInputKernels(kernels);
  ASSERT_EQ(input_kernels.size(), 1);
  auto output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels);
  ASSERT_EQ(output_kernels.size(), 1);
  // Two tensors enter the subgraph (the two inputs of `head`) and one
  // leaves it (the output of `tail`).
  auto input_tensors = kernel::LiteKernelUtil::SubgraphInputTensors(kernels);
  ASSERT_EQ(input_tensors.size(), 2);
  auto output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels);
  ASSERT_EQ(output_tensors.size(), 1);
}
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue