diff --git a/mindspore/lite/java/src/test/java/com/mindspore/ModelTest.java b/mindspore/lite/java/src/test/java/com/mindspore/ModelTest.java
index 7dbe08af118..d8545bdb6fd 100644
--- a/mindspore/lite/java/src/test/java/com/mindspore/ModelTest.java
+++ b/mindspore/lite/java/src/test/java/com/mindspore/ModelTest.java
@@ -61,6 +61,7 @@ public class ModelTest {
         Graph g = new Graph();
         assertTrue(g.load("../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_train.ms"));
         MSContext context = new MSContext();
+        context.init(1, 0);
         TrainCfg cfg = new TrainCfg();
         Model liteModel = new Model();
         boolean isSuccess = liteModel.build(g, context, cfg);
@@ -74,7 +75,7 @@ public class ModelTest {
         Graph g = new Graph();
         assertTrue(g.load(modelFile));
         MSContext context = new MSContext();
-        context.init();
+        context.init(1, 0);
         context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
         Model liteModel = new Model();
         boolean isSuccess = liteModel.build(g, context, null);
@@ -119,6 +120,7 @@ public class ModelTest {
     public void testBuildByFileFailed() {
         String modelFile = "../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_tod_infer.ms";
         MSContext context = new MSContext();
+        context.init(1, 0);
         Model liteModel = new Model();
         boolean isSuccess = liteModel.build(modelFile, 0, context);
         assertFalse(isSuccess);
diff --git a/mindspore/lite/src/runtime/cxx_api/context.h b/mindspore/lite/src/runtime/cxx_api/context.h
index 0d1914b1bca..a9a9d8aca4b 100644
--- a/mindspore/lite/src/runtime/cxx_api/context.h
+++ b/mindspore/lite/src/runtime/cxx_api/context.h
@@ -33,7 +33,7 @@ struct Context::Data {
   std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
 #ifdef PARALLEL_INFERENCE
-  int32_t thread_num = 8;
+  int32_t thread_num = 0;  // defaults are automatically adjusted based on computer performance
   int affinity_mode_ = 1;
   int32_t inter_op_parallel_num_ = 4;
 #else
diff --git a/mindspore/lite/src/runtime/cxx_api/model_pool/model_pool.cc b/mindspore/lite/src/runtime/cxx_api/model_pool/model_pool.cc
index 9347d9408c7..a1fea4d5619 100644
--- a/mindspore/lite/src/runtime/cxx_api/model_pool/model_pool.cc
+++ b/mindspore/lite/src/runtime/cxx_api/model_pool/model_pool.cc
@@ -296,10 +296,19 @@ std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<RunnerConfig> &runner_config) {
-  if (context->GetThreadNum() <= 0) {
+  if (context->GetThreadNum() < 0) {
     MS_LOG(ERROR) << "Invalid thread num " << context->GetThreadNum();
     return nullptr;
   }
+  if (context->GetThreadNum() == 0) {
+    // Defaults are automatically adjusted based on computer performance
+    auto thread_num = GetDefaultThreadNum();
+    if (thread_num == 0) {
+      MS_LOG(ERROR) << "computer thread num failed.";
+      return nullptr;
+    }
+    context->SetThreadNum(thread_num);
+  }
   if (!context->GetThreadAffinityCoreList().empty()) {
     MS_LOG(ERROR) << "parallel predict not support user set core list.";
     return nullptr;
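Note: `GetDefaultThreadNum()` is called in the hunk above but its definition is not part of this diff. A minimal sketch of what such a helper could look like, assuming it simply derives the count from the detected hardware concurrency (the real implementation in model_pool.cc may differ):

```cpp
#include <algorithm>
#include <cstdint>
#include <thread>

// Hypothetical sketch only -- the actual GetDefaultThreadNum() is not shown
// in this diff. Assumption: the default is derived from the hardware thread
// count, capped at the old fixed default of 8.
int32_t GetDefaultThreadNumSketch() {
  uint32_t hw_threads = std::thread::hardware_concurrency();
  if (hw_threads == 0) {
    // hardware_concurrency() returns 0 when the value cannot be determined;
    // the caller above treats 0 as failure and logs an error.
    return 0;
  }
  constexpr uint32_t kMaxDefault = 8;
  return static_cast<int32_t>(std::min(hw_threads, kMaxDefault));
}
```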
diff --git a/mindspore/lite/test/ut/python/test_inference_api.py b/mindspore/lite/test/ut/python/test_inference_api.py
index 7e13b6e7d02..b806b34bca8 100644
--- a/mindspore/lite/test/ut/python/test_inference_api.py
+++ b/mindspore/lite/test/ut/python/test_inference_api.py
@@ -116,51 +116,51 @@ def test_context_04():
 
 def test_context_05():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context(thread_affinity_mode="1")
+        context = mslite.Context(thread_num=2, thread_affinity_mode="1")
     assert "thread_affinity_mode must be int" in str(raise_info.value)
 
 
 def test_context_06():
-    context = mslite.Context(thread_affinity_mode=2)
+    context = mslite.Context(thread_num=2, thread_affinity_mode=2)
     assert "thread_affinity_mode: 2" in str(context)
 
 
 def test_context_07():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context(thread_affinity_core_list=2)
+        context = mslite.Context(thread_num=2, thread_affinity_core_list=2)
     assert "thread_affinity_core_list must be list" in str(raise_info.value)
 
 
 def test_context_08():
-    context = mslite.Context(thread_affinity_core_list=[2])
+    context = mslite.Context(thread_num=2, thread_affinity_core_list=[2])
     assert "thread_affinity_core_list: [2]" in str(context)
 
 
 def test_context_09():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context(thread_affinity_core_list=["1", "0"])
+        context = mslite.Context(thread_num=2, thread_affinity_core_list=["1", "0"])
     assert "thread_affinity_core_list element must be int" in str(raise_info.value)
 
 
 def test_context_10():
-    context = mslite.Context(thread_affinity_core_list=[1, 0])
+    context = mslite.Context(thread_num=2, thread_affinity_core_list=[1, 0])
     assert "thread_affinity_core_list: [1, 0]" in str(context)
 
 
 def test_context_11():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context(enable_parallel=1)
+        context = mslite.Context(thread_num=2, enable_parallel=1)
     assert "enable_parallel must be bool" in str(raise_info.value)
 
 
 def test_context_12():
-    context = mslite.Context(enable_parallel=True)
+    context = mslite.Context(thread_num=2, enable_parallel=True)
     assert "enable_parallel: True" in str(context)
 
 
 def test_context_13():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context()
+        context = mslite.Context(thread_num=2)
         context.append_device_info("CPUDeviceInfo")
     assert "device_info must be DeviceInfo" in str(raise_info.value)
 
@@ -168,7 +168,7 @@ def test_context_13():
 def test_context_14():
     gpu_device_info = mslite.GPUDeviceInfo()
     cpu_device_info = mslite.CPUDeviceInfo()
-    context = mslite.Context()
+    context = mslite.Context(thread_num=2)
     context.append_device_info(gpu_device_info)
     context.append_device_info(cpu_device_info)
     assert "device_list: 1, 0" in str(context)
@@ -293,7 +293,7 @@ def test_model_01():
 
 def test_model_build_01():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context()
+        context = mslite.Context(thread_num=2)
         model = mslite.Model()
         model.build_from_file(model_path=1, model_type=mslite.ModelType.MINDIR_LITE, context=context)
     assert "model_path must be str" in str(raise_info.value)
@@ -317,7 +317,7 @@ def test_model_build_03():
 
 def test_model_build_04():
     with pytest.raises(RuntimeError) as raise_info:
-        context = mslite.Context()
+        context = mslite.Context(thread_num=2)
         model = mslite.Model()
         model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)
     assert "model_path does not exist" in str(raise_info.value)
@@ -325,7 +325,7 @@ def test_model_build_04():
 
 def get_model():
     cpu_device_info = mslite.CPUDeviceInfo()
-    context = mslite.Context()
+    context = mslite.Context(thread_num=2)
     context.append_device_info(cpu_device_info)
     model = mslite.Model()
     model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)
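With `thread_num` now defaulting to 0 ("auto-adjust"), the Python tests above pin an explicit value so their string representations stay deterministic. The C++ tests below apply the same idea through `Context::SetThreadNum`; a condensed sketch of the pattern, assuming only the public `include/api/context.h` API (the value 2 is an arbitrary test choice):

```cpp
#include <memory>
#include "include/api/context.h"

// Sketch of the updated test setup: pin the thread count explicitly instead
// of relying on the new auto-selected default.
std::shared_ptr<mindspore::Context> MakeTestContext() {
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(2);  // leaving this unset (0) would mean "auto-adjust"
  auto cpu = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu);
  return context;
}
```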
diff --git a/mindspore/lite/test/ut/src/api/model_parallel_runner_test.cc b/mindspore/lite/test/ut/src/api/model_parallel_runner_test.cc
index 5b5fe458d01..ad36da3e6f4 100644
--- a/mindspore/lite/test/ut/src/api/model_parallel_runner_test.cc
+++ b/mindspore/lite/test/ut/src/api/model_parallel_runner_test.cc
@@ -129,6 +129,7 @@ TEST_F(ModelParallelRunnerTest, PredictWithoutInput) {
   auto context = std::make_shared<Context>();
   ASSERT_NE(nullptr, context);
+  context->SetThreadNum(2);
   auto &device_list = context->MutableDeviceInfo();
   auto device_info = std::make_shared<CPUDeviceInfo>();
   ASSERT_NE(nullptr, device_info);
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc
index e18f8e117eb..b42eba5ee87 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/cxx_api/model_test.cc
@@ -21,6 +21,9 @@
 #include "include/api/metrics/accuracy.h"
 
 namespace mindspore {
+namespace {
+constexpr int32_t kNumThreads = 2;
+}
 class TestCxxApiLiteModel : public mindspore::CommonTest {
  public:
   TestCxxApiLiteModel() = default;
@@ -40,6 +43,7 @@ TEST_F(TestCxxApiLiteModel, test_build_graph_uninitialized_FAILED) {
   Model model;
   GraphCell graph_cell;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -50,6 +54,7 @@ TEST_F(TestCxxApiLiteModel, test_build_SUCCES) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -66,6 +71,7 @@ TEST_F(TestCxxApiLiteModel, test_train_mode_SUCCES) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -85,6 +91,7 @@ TEST_F(TestCxxApiLiteModel, test_outputs_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -106,6 +113,7 @@ TEST_F(TestCxxApiLiteModel, test_metrics_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -121,6 +129,7 @@ TEST_F(TestCxxApiLiteModel, test_getparams_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
   auto train_cfg = std::make_shared<TrainCfg>();
@@ -157,6 +166,7 @@ TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
   auto train_cfg = std::make_shared<TrainCfg>();
@@ -188,6 +198,7 @@ TEST_F(TestCxxApiLiteModel, test_fp32_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   cpu_context->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -201,6 +212,7 @@ TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   cpu_context->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -216,6 +228,7 @@ TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   cpu_context->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -239,6 +252,7 @@ TEST_F(TestCxxApiLiteModel, set_get_lr_SUCCESS) {
   Graph graph;
   float learn_rate = 0.2;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<CPUDeviceInfo>();
   cpu_context->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu_context);
diff --git a/tests/ut/cpp/cxx_api/context_test.cc b/tests/ut/cpp/cxx_api/context_test.cc
index 7e50a2f4589..9322e2e3828 100644
--- a/tests/ut/cpp/cxx_api/context_test.cc
+++ b/tests/ut/cpp/cxx_api/context_test.cc
@@ -18,6 +18,9 @@
 #include "include/api/context.h"
 
 namespace mindspore {
+namespace {
+constexpr int32_t kNumThreads = 2;
+}
 class TestCxxApiContext : public UT::Common {
  public:
   TestCxxApiContext() = default;
@@ -59,6 +62,7 @@ TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
 
 TEST_F(TestCxxApiContext, test_context_cpu_context_SUCCESS) {
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   std::shared_ptr<CPUDeviceInfo> cpu = std::make_shared<CPUDeviceInfo>();
   cpu->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu);
@@ -81,6 +85,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
   std::string option_9_ans = "1,2,3,4,5";
 
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   std::shared_ptr<Ascend310DeviceInfo> ascend310 = std::make_shared<Ascend310DeviceInfo>();
   ascend310->SetInputShape(option_1);
   ascend310->SetInsertOpConfigPath(option_2);
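Net effect of the patch: in `PARALLEL_INFERENCE` builds, a default-constructed `Context` now reports `thread_num == 0`, which downstream code (here, the model pool) resolves to a machine-appropriate value, while the tests pin an explicit count for reproducibility. A minimal sketch of both call sites, assuming only the public API shown in this diff:

```cpp
#include <memory>
#include "include/api/context.h"

int main() {
  // Rely on the new auto-adjusted default: GetThreadNum() stays 0 until the
  // runtime (e.g. ModelPool::GetUserDefineContext above) computes a value.
  auto auto_ctx = std::make_shared<mindspore::Context>();

  // Or pin the count explicitly, as the updated unit tests do:
  auto pinned_ctx = std::make_shared<mindspore::Context>();
  pinned_ctx->SetThreadNum(2);
  return 0;
}
```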