mindrt parallel ut test

This commit is contained in:
ling 2021-06-28 17:45:00 +08:00
parent faaf923972
commit 904e56a757
8 changed files with 188 additions and 16 deletions

View File

@ -229,24 +229,22 @@ void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
MS_LOG(ERROR) << "New sub graph failed!";
return;
}
new_sub_graph->name_ = "subgraph-split-" + std::to_string(new_sub_index);
new_sub_graph->name_ = "SubSplit" + std::to_string(new_sub_index);
Model::Node *new_partial_node = new (std::nothrow) Model::Node();
if (new_partial_node == nullptr) {
MS_LOG(ERROR) << "New partial node failed!";
delete new_sub_graph;
return;
}
new_partial_node->name_ = "Partial-subgraph-split-" + std::to_string(new_sub_index);
new_partial_node->name_ = "SubSplitPartial" + std::to_string(new_sub_index);
if (device_type == DT_CPU) {
new_partial_node->name_ = "cpu_" + new_partial_node->name_;
new_partial_node->name_ = "Cpu" + new_partial_node->name_;
} else if (device_type == DT_GPU) {
new_partial_node->name_ = "gpu_" + new_partial_node->name_;
new_partial_node->name_ = "Gpu" + new_partial_node->name_;
} else if (device_type == DT_NPU) {
new_partial_node->name_ = "npu_" + new_partial_node->name_;
} else {
new_partial_node->name_ = "unknow_" + new_partial_node->name_;
new_partial_node->name_ = "Npu" + new_partial_node->name_;
}
new_partial_node->node_type_ = mindspore::lite::NodeType_ValueNode;
new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);

View File

@ -68,3 +68,6 @@ echo 'run train ut tests'
## ./lite-test --gtest_filter="NetworkTest.lenetnet"
echo 'run inference ut tests'
./lite-test --gtest_filter="ControlFlowTest.TestMergeWhileModel"
echo 'run mindrt parallel ut test'
./lite-test --gtest_filter="MindrtParallelTest.offline1"

View File

@ -0,0 +1,168 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gtest/gtest.h"
#include "common/common_test.h"
#include "include/errorcode.h"
#include "tools/converter/converter.h"
#include "tools/benchmark/benchmark.h"
#include "src/mindrt_executor.h"
#include "src/lite_session.h"
#include "src/lite_kernel.h"
namespace mindspore {
class MindrtParallelTest : public mindspore::CommonTest {
public:
MindrtParallelTest() {}
};
int CheckOffline1(session::LiteSession *session) {
/* ----------- start check -------------- */
lite::LiteSession *lite_session = reinterpret_cast<lite::LiteSession *>(session);
auto kernels = lite_session->get_kernels();
if (kernels.size() != 4) {
return -1;
}
/* sub-graph-0 */
kernel::SubGraphKernel *subgraph0 = reinterpret_cast<kernel::SubGraphKernel *>(kernels[0]);
std::vector<kernel::LiteKernel *> nodes0 = subgraph0->nodes();
if (nodes0.size() != 1) {
return -2;
}
if (nodes0[0]->type() != schema::PrimitiveType_SplitWithOverlap) {
return -3;
}
/* sub-graph-1 */
kernel::SubGraphKernel *subgraph1 = reinterpret_cast<kernel::SubGraphKernel *>(kernels[1]);
std::vector<kernel::LiteKernel *> nodes1 = subgraph1->nodes();
if (nodes1.size() != 3) {
return -4;
}
if (nodes1[0]->type() != schema::PrimitiveType_Conv2DFusion ||
nodes1[1]->type() != schema::PrimitiveType_Conv2DFusion ||
nodes1[2]->type() != schema::PrimitiveType_Conv2DFusion) {
return -5;
}
/* sub-graph-2 */
kernel::SubGraphKernel *subgraph2 = reinterpret_cast<kernel::SubGraphKernel *>(kernels[2]);
std::vector<kernel::LiteKernel *> nodes2 = subgraph2->nodes();
if (nodes2.size() != 3) {
return -6;
}
if (nodes2[0]->type() != schema::PrimitiveType_Conv2DFusion ||
nodes2[1]->type() != schema::PrimitiveType_Conv2DFusion ||
nodes2[2]->type() != schema::PrimitiveType_Conv2DFusion) {
return -7;
}
/* sub-graph-3 */
kernel::SubGraphKernel *subgraph3 = reinterpret_cast<kernel::SubGraphKernel *>(kernels[3]);
std::vector<kernel::LiteKernel *> nodes3 = subgraph3->nodes();
if (nodes3.size() != 8) {
return -8;
}
if (nodes3[0]->type() != schema::PrimitiveType_Concat) {
return -9;
}
return lite::RET_OK;
}
int CheckRuntime1(session::LiteSession *session) {
lite::LiteSession *lite_session = reinterpret_cast<lite::LiteSession *>(session);
auto kernels = lite_session->get_kernels();
if (kernels.size() != 6) {
return -1;
}
return lite::RET_OK;
}
TEST_F(MindrtParallelTest, offline1) {
const char *converter_argv[] = {"./converter", "--fmk=TFLITE",
"--modelFile=./mindrtParallel/mindrt_parallel_model.tflite",
"--outputFile=./mindrtParallel/mindrt_parallel_model_split",
"--configFile=./mindrtParallel/mindrt_parallel_model.config"};
int converter_ret = mindspore::lite::RunConverter(5, converter_argv);
ASSERT_EQ(converter_ret, lite::RET_OK);
size_t size = 0;
char *graph_buf = lite::ReadFile("./mindrtParallel/mindrt_parallel_model_split.ms", &size);
ASSERT_NE(graph_buf, nullptr);
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf, size));
delete[](graph_buf);
ASSERT_NE(model, nullptr);
auto context = std::make_shared<lite::Context>();
ASSERT_NE(context, nullptr);
context->enable_parallel_ = true;
session::LiteSession *session = session::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);
int benchmark_ret = session->CompileGraph(model.get());
ASSERT_EQ(benchmark_ret, lite::RET_OK);
ASSERT_EQ(CheckOffline1(session), lite::RET_OK);
auto inputs = session->GetInputs();
for (auto in : inputs) {
in->MutableData();
}
benchmark_ret = session->RunGraph(nullptr, nullptr);
ASSERT_EQ(benchmark_ret, lite::RET_OK);
}
TEST_F(MindrtParallelTest, runtime1) {
const char *converter_argv[] = {"./converter", "--fmk=TFLITE",
"--modelFile=./mindrtParallel/mindrt_parallel_model.tflite",
"--outputFile=./mindrtParallel/mindrt_parallel_model"};
int converter_ret = mindspore::lite::RunConverter(4, converter_argv);
ASSERT_EQ(converter_ret, lite::RET_OK);
size_t size = 0;
char *graph_buf = lite::ReadFile("./mindrtParallel/mindrt_parallel_model.ms", &size);
ASSERT_NE(graph_buf, nullptr);
auto model = std::shared_ptr<lite::Model>(lite::Model::Import(graph_buf, size));
delete[](graph_buf);
ASSERT_NE(model, nullptr);
auto context = std::make_shared<lite::Context>();
ASSERT_NE(context, nullptr);
context->enable_parallel_ = true;
session::LiteSession *session = session::LiteSession::CreateSession(context.get());
ASSERT_NE(session, nullptr);
int benchmark_ret = session->CompileGraph(model.get());
ASSERT_EQ(benchmark_ret, lite::RET_OK);
ASSERT_EQ(CheckRuntime1(session), lite::RET_OK);
auto inputs = session->GetInputs();
for (auto in : inputs) {
in->MutableData();
}
benchmark_ret = session->RunGraph(nullptr, nullptr);
ASSERT_EQ(benchmark_ret, lite::RET_OK);
}
} // namespace mindspore

View File

@ -0,0 +1,3 @@
device0=cpu
device1=cpu
computeRate=device0:1;device1:2;

File diff suppressed because one or more lines are too long

View File

@ -219,11 +219,8 @@ int Flags::InitConfigFile() {
return RET_INPUT_PARAM_INVALID;
}
}
if (parallel_split_config_.parallel_split_type_ != SplitNo &&
!CheckOfflineParallelConfig(this->configFile, &parallel_split_config_)) {
std::cerr << "offline kernel parallel split config set error." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
(void)CheckOfflineParallelConfig(this->configFile, &parallel_split_config_);
return RET_OK;
}
@ -319,17 +316,14 @@ bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *pa
std::vector<std::string> config_devices = {"cpu", "gpu", "npu"};
auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate);
if (compute_rate_result.empty()) {
std::cerr << "config setting error: compute rate should be set." << std::endl;
return false;
}
std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0);
if (device0_result.empty()) {
std::cerr << "config setting error: device0 should be set." << std::endl;
return false;
}
std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1);
if (device1_result.empty()) {
std::cerr << "config setting error: device1 should be set." << std::endl;
return false;
}
bool device0_flag = false;