enable mindrt st test

This commit is contained in:
ling 2021-07-01 09:20:31 +08:00
parent af8c15265d
commit 797b65f97c
5 changed files with 56 additions and 13 deletions

View File

@ -72,6 +72,8 @@ class LiteSession : public session::LiteSession {
void set_model(Model *model) { this->model_ = model; }
const std::vector<kernel::LiteKernel *> &get_kernels() const { return this->kernels_; }
protected:
static void ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lite::Tensor *dst_tensor);

View File

@ -74,9 +74,15 @@ int SplitWithOverlapBaseCPUKernel::ReSize() {
param_->inner_stride_ = 1;
for (int i = 0; i < static_cast<int>(input_shape.size()); i++) {
if (i < param_->split_dim_) param_->outer_total_dim_ *= input_shape[i];
if (i == param_->split_dim_) param_->split_dim_size_ = input_shape[param_->split_dim_];
if (i > param_->split_dim_) param_->inner_stride_ *= input_shape[i];
if (i < param_->split_dim_) {
param_->outer_total_dim_ *= input_shape[i];
}
if (i == param_->split_dim_) {
param_->split_dim_size_ = input_shape[param_->split_dim_];
}
if (i > param_->split_dim_) {
param_->inner_stride_ *= input_shape[i];
}
}
thread_count_ = MSMIN(param_->num_split_, op_parameter_->thread_num_);

View File

@ -111,13 +111,11 @@ set(TEST_LITE_SRC
file(GLOB KERNEL_REG_SRC ${LITE_DIR}/src/registry/*.cc)
set(TEST_LITE_SRC ${TEST_LITE_SRC} ${KERNEL_REG_SRC})
if(ENABLE_TOOLS)
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${LITE_DIR}/tools/benchmark/benchmark.cc
${LITE_DIR}/test/st/benchmark_test.cc
)
endif()
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${LITE_DIR}/tools/benchmark/benchmark.cc
${LITE_DIR}/test/st/benchmark_test.cc
)
### gpu runtime
if(MSLITE_GPU_BACKEND STREQUAL opencl)
@ -316,6 +314,7 @@ if(MSLITE_ENABLE_CONVERTER)
${TEST_SRC}
${TEST_DIR}/st/converter_test.cc
${TEST_DIR}/st/control_flow_test.cc
${TEST_DIR}/st/mindrt_parallel_test.cc
${TEST_DIR}/st/sub_graph_test.cc
${TEST_DIR}/common/import_from_meta_graphT.cc
${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc

2
mindspore/lite/test/runtest.sh Executable file → Normal file
View File

@ -70,4 +70,4 @@ echo 'run inference ut tests'
./lite-test --gtest_filter="ControlFlowTest.TestMergeWhileModel"
echo 'run mindrt parallel ut test'
./lite-test --gtest_filter="MindrtParallelTest.offline1"
./lite-test --gtest_filter="MindrtParallelTest.*"

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -128,6 +128,41 @@ TEST_F(MindrtParallelTest, offline1) {
benchmark_ret = session->RunGraph(nullptr, nullptr);
ASSERT_EQ(benchmark_ret, lite::RET_OK);
delete session;
}
TEST_F(MindrtParallelTest, offline2) {
const char *benchmark_argv1[] = {"./benchmark",
"--enableParallel=true",
"--numThreads=3",
"--modelFile=./mindrtParallel/mindrt_parallel_model_split.ms",
"--inDataFile=./mindrtParallel/mindrt_parallel_model.bin",
"--benchmarkDataFile=./mindrtParallel/mindrt_parallel_model.out"};
int converter_ret = mindspore::lite::RunBenchmark(6, benchmark_argv1);
ASSERT_EQ(converter_ret, lite::RET_OK);
}
TEST_F(MindrtParallelTest, offline3) {
const char *benchmark_argv2[] = {"./benchmark",
"--enableParallel=true",
"--numThreads=4",
"--modelFile=./mindrtParallel/mindrt_parallel_model_split.ms",
"--inDataFile=./mindrtParallel/mindrt_parallel_model.bin",
"--benchmarkDataFile=./mindrtParallel/mindrt_parallel_model.out"};
int converter_ret = mindspore::lite::RunBenchmark(6, benchmark_argv2);
ASSERT_EQ(converter_ret, lite::RET_OK);
}
TEST_F(MindrtParallelTest, offline4) {
const char *benchmark_argv3[] = {"./benchmark",
"--enableParallel=false",
"--numThreads=1",
"--modelFile=./mindrtParallel/mindrt_parallel_model_split.ms",
"--inDataFile=./mindrtParallel/mindrt_parallel_model.bin",
"--benchmarkDataFile=./mindrtParallel/mindrt_parallel_model.out"};
int converter_ret = mindspore::lite::RunBenchmark(6, benchmark_argv3);
ASSERT_EQ(converter_ret, lite::RET_OK);
}
TEST_F(MindrtParallelTest, runtime1) {
@ -161,8 +196,9 @@ TEST_F(MindrtParallelTest, runtime1) {
for (auto in : inputs) {
in->MutableData();
}
benchmark_ret = session->RunGraph(nullptr, nullptr);
ASSERT_EQ(benchmark_ret, lite::RET_OK);
delete session;
}
} // namespace mindspore