modify default thread num

yefeng 2022-06-23 17:11:30 +08:00
parent 6ac3c935b8
commit 817c40acc0
7 changed files with 47 additions and 16 deletions

View File

@@ -61,6 +61,7 @@ public class ModelTest {
        Graph g = new Graph();
        assertTrue(g.load("../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_train.ms"));
        MSContext context = new MSContext();
        context.init(1, 0);
        TrainCfg cfg = new TrainCfg();
        Model liteModel = new Model();
        boolean isSuccess = liteModel.build(g, context, cfg);
@@ -74,7 +75,7 @@ public class ModelTest {
        Graph g = new Graph();
        assertTrue(g.load(modelFile));
        MSContext context = new MSContext();
        context.init();
        context.init(1, 0);
        context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
        Model liteModel = new Model();
        boolean isSuccess = liteModel.build(g, context, null);
@@ -119,6 +120,7 @@ public class ModelTest {
    public void testBuildByFileFailed() {
        String modelFile = "../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_tod_infer.ms";
        MSContext context = new MSContext();
        context.init(1, 0);
        Model liteModel = new Model();
        boolean isSuccess = liteModel.build(modelFile, 0, context);
        assertFalse(isSuccess);

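The Java tests above now size the context explicitly instead of relying on the old default. As a rough C++ equivalent, assuming MSContext.init(1, 0) means one worker thread and CPU bind mode 0 (no core binding); the helper name below is illustrative and not part of this commit:

// Sketch only: mirrors the Java context.init(1, 0) calls above using the
// C++ Context API that the later files in this commit touch.
#include <memory>
#include "include/api/context.h"

std::shared_ptr<mindspore::Context> MakeSingleThreadContext() {
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(1);       // explicit thread count; 0 becomes "auto" after this commit
  context->SetThreadAffinity(0);  // assumption: bind mode 0 means no core binding
  return context;
}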
View File

@@ -33,7 +33,7 @@ struct Context::Data {
  std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
#ifdef PARALLEL_INFERENCE
  int32_t thread_num = 8;
  int32_t thread_num = 0;  // 0: the default is chosen automatically based on machine performance
  int affinity_mode_ = 1;
  int32_t inter_op_parallel_num_ = 4;
#else

View File

@@ -296,10 +296,19 @@ std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<R
    MS_LOG(ERROR) << "user set config context nullptr.";
    return nullptr;
  }
  if (context->GetThreadNum() <= 0) {
  if (context->GetThreadNum() < 0) {
    MS_LOG(ERROR) << "Invalid thread num " << context->GetThreadNum();
    return nullptr;
  }
  if (context->GetThreadNum() == 0) {
    // A thread num of 0 means the default is chosen automatically based on machine performance
    auto thread_num = GetDefaultThreadNum();
    if (thread_num == 0) {
      MS_LOG(ERROR) << "compute default thread num failed.";
      return nullptr;
    }
    context->SetThreadNum(thread_num);
  }
  if (!context->GetThreadAffinityCoreList().empty()) {
    MS_LOG(ERROR) << "parallel predict not support user set core list.";
    return nullptr;

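GetDefaultThreadNum() is called above but its definition is not part of this diff. Judging from the comment and the zero-check, it probes the host and returns 0 on failure; a hypothetical sketch under those assumptions, not the commit's actual implementation:

// Illustrative guess at the contract of GetDefaultThreadNum(); the real
// body lives in a file this commit page does not show.
#include <cstdint>
#include <thread>

int32_t GetDefaultThreadNumSketch() {
  // hardware_concurrency() returns 0 when the core count is unknown, which
  // the caller reports as "compute default thread num failed".
  uint32_t cores = std::thread::hardware_concurrency();
  if (cores == 0) {
    return 0;
  }
  // Assumption: leave headroom rather than claiming every core.
  return static_cast<int32_t>(cores > 1 ? cores / 2 : 1);
}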
View File

@@ -116,51 +116,51 @@ def test_context_04():

def test_context_05():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context(thread_affinity_mode="1")
        context = mslite.Context(thread_num=2, thread_affinity_mode="1")
    assert "thread_affinity_mode must be int" in str(raise_info.value)


def test_context_06():
    context = mslite.Context(thread_affinity_mode=2)
    context = mslite.Context(thread_num=2, thread_affinity_mode=2)
    assert "thread_affinity_mode: 2" in str(context)


def test_context_07():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context(thread_affinity_core_list=2)
        context = mslite.Context(thread_num=2, thread_affinity_core_list=2)
    assert "thread_affinity_core_list must be list" in str(raise_info.value)


def test_context_08():
    context = mslite.Context(thread_affinity_core_list=[2])
    context = mslite.Context(thread_num=2, thread_affinity_core_list=[2])
    assert "thread_affinity_core_list: [2]" in str(context)


def test_context_09():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context(thread_affinity_core_list=["1", "0"])
        context = mslite.Context(thread_num=2, thread_affinity_core_list=["1", "0"])
    assert "thread_affinity_core_list element must be int" in str(raise_info.value)


def test_context_10():
    context = mslite.Context(thread_affinity_core_list=[1, 0])
    context = mslite.Context(thread_num=2, thread_affinity_core_list=[1, 0])
    assert "thread_affinity_core_list: [1, 0]" in str(context)


def test_context_11():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context(enable_parallel=1)
        context = mslite.Context(thread_num=2, enable_parallel=1)
    assert "enable_parallel must be bool" in str(raise_info.value)


def test_context_12():
    context = mslite.Context(enable_parallel=True)
    context = mslite.Context(thread_num=2, enable_parallel=True)
    assert "enable_parallel: True" in str(context)


def test_context_13():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context = mslite.Context(thread_num=2)
        context.append_device_info("CPUDeviceInfo")
    assert "device_info must be DeviceInfo" in str(raise_info.value)
@@ -168,7 +168,7 @@ def test_context_13():
def test_context_14():
    gpu_device_info = mslite.GPUDeviceInfo()
    cpu_device_info = mslite.CPUDeviceInfo()
    context = mslite.Context()
    context = mslite.Context(thread_num=2)
    context.append_device_info(gpu_device_info)
    context.append_device_info(cpu_device_info)
    assert "device_list: 1, 0" in str(context)
@@ -293,7 +293,7 @@ def test_model_01():

def test_model_build_01():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context = mslite.Context(thread_num=2)
        model = mslite.Model()
        model.build_from_file(model_path=1, model_type=mslite.ModelType.MINDIR_LITE, context=context)
    assert "model_path must be str" in str(raise_info.value)
@@ -317,7 +317,7 @@ def test_model_build_03():

def test_model_build_04():
    with pytest.raises(RuntimeError) as raise_info:
        context = mslite.Context()
        context = mslite.Context(thread_num=2)
        model = mslite.Model()
        model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)
    assert "model_path does not exist" in str(raise_info.value)
@@ -325,7 +325,7 @@ def test_model_build_04():

def get_model():
    cpu_device_info = mslite.CPUDeviceInfo()
    context = mslite.Context()
    context = mslite.Context(thread_num=2)
    context.append_device_info(cpu_device_info)
    model = mslite.Model()
    model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)

View File

@@ -129,6 +129,7 @@ TEST_F(ModelParallelRunnerTest, PredictWithoutInput) {
  auto context = std::make_shared<Context>();
  ASSERT_NE(nullptr, context);
  context->SetThreadNum(2);
  auto &device_list = context->MutableDeviceInfo();
  auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
  ASSERT_NE(nullptr, device_info);

View File

@@ -21,6 +21,9 @@
#include "include/api/metrics/accuracy.h"
namespace mindspore {
namespace {
constexpr int32_t kNumThreads = 2;
}
class TestCxxApiLiteModel : public mindspore::CommonTest {
 public:
  TestCxxApiLiteModel() = default;
@@ -40,6 +43,7 @@ TEST_F(TestCxxApiLiteModel, test_build_graph_uninitialized_FAILED) {
  Model model;
  GraphCell graph_cell;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu_context);
@@ -50,6 +54,7 @@ TEST_F(TestCxxApiLiteModel, test_build_SUCCES) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu_context);
@@ -66,6 +71,7 @@ TEST_F(TestCxxApiLiteModel, test_train_mode_SUCCES) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu_context);
@@ -85,6 +91,7 @@ TEST_F(TestCxxApiLiteModel, test_outputs_SUCCESS) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu_context);
@@ -106,6 +113,7 @@ TEST_F(TestCxxApiLiteModel, test_metrics_SUCCESS) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu_context);
@@ -121,6 +129,7 @@ TEST_F(TestCxxApiLiteModel, test_getparams_SUCCESS) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu_context);
  auto train_cfg = std::make_shared<TrainCfg>();
@@ -157,6 +166,7 @@ TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu_context);
  auto train_cfg = std::make_shared<TrainCfg>();
@@ -188,6 +198,7 @@ TEST_F(TestCxxApiLiteModel, test_fp32_SUCCESS) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  cpu_context->SetEnableFP16(true);
  context->MutableDeviceInfo().push_back(cpu_context);
@@ -201,6 +212,7 @@ TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  cpu_context->SetEnableFP16(true);
  context->MutableDeviceInfo().push_back(cpu_context);
@@ -216,6 +228,7 @@ TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) {
  Model model;
  Graph graph;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  cpu_context->SetEnableFP16(true);
  context->MutableDeviceInfo().push_back(cpu_context);
@@ -239,6 +252,7 @@ TEST_F(TestCxxApiLiteModel, set_get_lr_SUCCESS) {
  Graph graph;
  float learn_rate = 0.2;
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
  cpu_context->SetEnableFP16(true);
  context->MutableDeviceInfo().push_back(cpu_context);

View File

@@ -18,6 +18,9 @@
#include "include/api/context.h"
namespace mindspore {
namespace {
constexpr int32_t kNumThreads = 2;
}
class TestCxxApiContext : public UT::Common {
 public:
  TestCxxApiContext() = default;
@@ -59,6 +62,7 @@ TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
TEST_F(TestCxxApiContext, test_context_cpu_context_SUCCESS) {
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  std::shared_ptr<CPUDeviceInfo> cpu = std::make_shared<CPUDeviceInfo>();
  cpu->SetEnableFP16(true);
  context->MutableDeviceInfo().push_back(cpu);
@@ -81,6 +85,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
  std::string option_9_ans = "1,2,3,4,5";
  auto context = std::make_shared<Context>();
  context->SetThreadNum(kNumThreads);
  std::shared_ptr<AscendDeviceInfo> ascend310 = std::make_shared<AscendDeviceInfo>();
  ascend310->SetInputShape(option_1);
  ascend310->SetInsertOpConfigPath(option_2);