modify default thread num
parent 6ac3c935b8
commit 817c40acc0
@@ -61,6 +61,7 @@ public class ModelTest {
         Graph g = new Graph();
         assertTrue(g.load("../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_train.ms"));
         MSContext context = new MSContext();
+        context.init(1, 0);
         TrainCfg cfg = new TrainCfg();
         Model liteModel = new Model();
         boolean isSuccess = liteModel.build(g, context, cfg);
@@ -74,7 +75,7 @@ public class ModelTest {
         Graph g = new Graph();
         assertTrue(g.load(modelFile));
         MSContext context = new MSContext();
-        context.init();
+        context.init(1,0);
         context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
         Model liteModel = new Model();
         boolean isSuccess = liteModel.build(g, context, null);
@@ -119,6 +120,7 @@ public class ModelTest {
     public void testBuildByFileFailed() {
         String modelFile = "../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_tod_infer.ms";
         MSContext context = new MSContext();
+        context.init(1, 0);
         Model liteModel = new Model();
         boolean isSuccess = liteModel.build(modelFile, 0, context);
         assertFalse(isSuccess);
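The Java tests above now pin the thread count explicitly instead of relying on the old implicit default. For readers more familiar with the C++ API, a rough counterpart of MSContext.init(1, 0) is sketched below, assuming the Java signature is init(threadNum, cpuBindMode) and using only calls that appear elsewhere in this commit; an illustration, not code from the commit.

#include <memory>
#include "include/api/context.h"

// Sketch: C++ counterpart of the Java MSContext.init(1, 0) calls above.
std::shared_ptr<mindspore::Context> MakeTestContext() {
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(1);  // Java init's first argument: thread count
  auto cpu = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(cpu);  // mirrors addDeviceInfo(DT_CPU, ...)
  return context;
}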
@@ -33,7 +33,7 @@ struct Context::Data {
   std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
 
 #ifdef PARALLEL_INFERENCE
-  int32_t thread_num = 8;
+  int32_t thread_num = 0; // defaults are automatically adjusted based on computer performance
   int affinity_mode_ = 1;
   int32_t inter_op_parallel_num_ = 4;
 #else
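With the default dropping from a fixed 8 to 0, zero becomes a sentinel meaning "derive a value from the machine". The commit's actual GetDefaultThreadNum() is not part of this diff; the sketch below only illustrates one plausible resolution strategy, with std::thread::hardware_concurrency() standing in as an assumption.

#include <algorithm>
#include <cstdint>
#include <thread>

// Hypothetical stand-in for GetDefaultThreadNum(): map the sentinel 0 to a
// machine-derived value, returning 0 on failure as the diff's error path expects.
int32_t ResolveThreadNum(int32_t configured) {
  if (configured > 0) {
    return configured;  // an explicit user setting wins
  }
  // hardware_concurrency() returns 0 when the value cannot be determined.
  uint32_t hw = std::thread::hardware_concurrency();
  return hw == 0 ? 0 : static_cast<int32_t>(std::min<uint32_t>(hw, 8u));
}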
@@ -296,10 +296,19 @@ std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<R
     MS_LOG(ERROR) << "user set config context nullptr.";
     return nullptr;
   }
-  if (context->GetThreadNum() <= 0) {
+  if (context->GetThreadNum() < 0) {
     MS_LOG(ERROR) << "Invalid thread num " << context->GetThreadNum();
     return nullptr;
   }
+  if (context->GetThreadNum() == 0) {
+    // Defaults are automatically adjusted based on computer performance
+    auto thread_num = GetDefaultThreadNum();
+    if (thread_num == 0) {
+      MS_LOG(ERROR) << "computer thread num failed.";
+      return nullptr;
+    }
+    context->SetThreadNum(thread_num);
+  }
   if (!context->GetThreadAffinityCoreList().empty()) {
     MS_LOG(ERROR) << "parallel predict not support user set core list.";
     return nullptr;
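Net effect of this hunk: a negative thread count is rejected, zero triggers auto-detection via GetDefaultThreadNum(), and a positive value is kept as-is. From the caller's side, that makes the following configuration legal where it previously failed validation; a minimal sketch using only calls visible in this commit.

#include <memory>
#include "include/api/context.h"

// Sketch: with this change, thread_num == 0 means "auto-adjust" instead of "error".
std::shared_ptr<mindspore::Context> MakeAutoThreadContext() {
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(0);  // resolved by the pool via GetDefaultThreadNum()
  auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
  context->MutableDeviceInfo().push_back(device_info);
  return context;
}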
@@ -116,51 +116,51 @@ def test_context_04():
 
 def test_context_05():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context(thread_affinity_mode="1")
+        context = mslite.Context(thread_num=2, thread_affinity_mode="1")
     assert "thread_affinity_mode must be int" in str(raise_info.value)
 
 
 def test_context_06():
-    context = mslite.Context(thread_affinity_mode=2)
+    context = mslite.Context(thread_num=2, thread_affinity_mode=2)
     assert "thread_affinity_mode: 2" in str(context)
 
 
 def test_context_07():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context(thread_affinity_core_list=2)
+        context = mslite.Context(thread_num=2, thread_affinity_core_list=2)
     assert "thread_affinity_core_list must be list" in str(raise_info.value)
 
 
 def test_context_08():
-    context = mslite.Context(thread_affinity_core_list=[2])
+    context = mslite.Context(thread_num=2, thread_affinity_core_list=[2])
     assert "thread_affinity_core_list: [2]" in str(context)
 
 
 def test_context_09():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context(thread_affinity_core_list=["1", "0"])
+        context = mslite.Context(thread_num=2, thread_affinity_core_list=["1", "0"])
     assert "thread_affinity_core_list element must be int" in str(raise_info.value)
 
 
 def test_context_10():
-    context = mslite.Context(thread_affinity_core_list=[1, 0])
+    context = mslite.Context(thread_num=2, thread_affinity_core_list=[1, 0])
     assert "thread_affinity_core_list: [1, 0]" in str(context)
 
 
 def test_context_11():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context(enable_parallel=1)
+        context = mslite.Context(thread_num=2, enable_parallel=1)
     assert "enable_parallel must be bool" in str(raise_info.value)
 
 
 def test_context_12():
-    context = mslite.Context(enable_parallel=True)
+    context = mslite.Context(thread_num=2, enable_parallel=True)
     assert "enable_parallel: True" in str(context)
 
 
 def test_context_13():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context()
+        context = mslite.Context(thread_num=2)
         context.append_device_info("CPUDeviceInfo")
     assert "device_info must be DeviceInfo" in str(raise_info.value)
 
@@ -168,7 +168,7 @@ def test_context_13():
 def test_context_14():
     gpu_device_info = mslite.GPUDeviceInfo()
     cpu_device_info = mslite.CPUDeviceInfo()
-    context = mslite.Context()
+    context = mslite.Context(thread_num=2)
     context.append_device_info(gpu_device_info)
     context.append_device_info(cpu_device_info)
     assert "device_list: 1, 0" in str(context)
@@ -293,7 +293,7 @@ def test_model_01():
 
 def test_model_build_01():
     with pytest.raises(TypeError) as raise_info:
-        context = mslite.Context()
+        context = mslite.Context(thread_num=2)
         model = mslite.Model()
         model.build_from_file(model_path=1, model_type=mslite.ModelType.MINDIR_LITE, context=context)
     assert "model_path must be str" in str(raise_info.value)
@@ -317,7 +317,7 @@ def test_model_build_03():
 
 def test_model_build_04():
     with pytest.raises(RuntimeError) as raise_info:
-        context = mslite.Context()
+        context = mslite.Context(thread_num=2)
         model = mslite.Model()
         model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)
     assert "model_path does not exist" in str(raise_info.value)
@@ -325,7 +325,7 @@ def test_model_build_04():
 
 def get_model():
     cpu_device_info = mslite.CPUDeviceInfo()
-    context = mslite.Context()
+    context = mslite.Context(thread_num=2)
     context.append_device_info(cpu_device_info)
     model = mslite.Model()
     model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE, context=context)
@@ -129,6 +129,7 @@ TEST_F(ModelParallelRunnerTest, PredictWithoutInput) {
 
   auto context = std::make_shared<Context>();
   ASSERT_NE(nullptr, context);
+  context->SetThreadNum(2);
   auto &device_list = context->MutableDeviceInfo();
   auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
   ASSERT_NE(nullptr, device_info);
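The parallel-runner test now fixes its thread count the same way. For orientation, a hedged sketch of how such a context typically reaches the runner; RunnerConfig::SetContext, ModelParallelRunner::Init, and the header path are assumptions from the public API and do not appear in this diff.

#include <memory>
#include <string>
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"  // assumed header path

// Sketch: feeding the test's context into a parallel runner (assumed API).
bool BuildRunner(const std::string &model_path) {
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(2);  // as in the updated test above
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());

  auto runner_config = std::make_shared<mindspore::RunnerConfig>();
  runner_config->SetContext(context);  // assumed setter

  mindspore::ModelParallelRunner runner;
  return runner.Init(model_path, runner_config).IsOk();
}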
@@ -21,6 +21,9 @@
 #include "include/api/metrics/accuracy.h"
 
 namespace mindspore {
+namespace {
+constexpr int32_t kNumThreads = 2;
+}
 class TestCxxApiLiteModel : public mindspore::CommonTest {
  public:
   TestCxxApiLiteModel() = default;
@@ -40,6 +43,7 @@ TEST_F(TestCxxApiLiteModel, test_build_graph_uninitialized_FAILED) {
   Model model;
   GraphCell graph_cell;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
 
@@ -50,6 +54,7 @@ TEST_F(TestCxxApiLiteModel, test_build_SUCCES) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
 
@@ -66,6 +71,7 @@ TEST_F(TestCxxApiLiteModel, test_train_mode_SUCCES) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
 
@@ -85,6 +91,7 @@ TEST_F(TestCxxApiLiteModel, test_outputs_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
 
@@ -106,6 +113,7 @@ TEST_F(TestCxxApiLiteModel, test_metrics_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
 
@@ -121,6 +129,7 @@ TEST_F(TestCxxApiLiteModel, test_getparams_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
   auto train_cfg = std::make_shared<TrainCfg>();
@@ -157,6 +166,7 @@ TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   context->MutableDeviceInfo().push_back(cpu_context);
   auto train_cfg = std::make_shared<TrainCfg>();
@@ -188,6 +198,7 @@ TEST_F(TestCxxApiLiteModel, test_fp32_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   cpu_context->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -201,6 +212,7 @@ TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   cpu_context->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -216,6 +228,7 @@ TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) {
   Model model;
   Graph graph;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   cpu_context->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -239,6 +252,7 @@ TEST_F(TestCxxApiLiteModel, set_get_lr_SUCCESS) {
   Graph graph;
   float learn_rate = 0.2;
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
   cpu_context->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu_context);
@@ -18,6 +18,9 @@
 #include "include/api/context.h"
 
 namespace mindspore {
+namespace {
+constexpr int32_t kNumThreads = 2;
+}
 class TestCxxApiContext : public UT::Common {
  public:
   TestCxxApiContext() = default;
@@ -59,6 +62,7 @@ TEST_F(TestCxxApiContext, test_context_get_set_SUCCESS) {
 
 TEST_F(TestCxxApiContext, test_context_cpu_context_SUCCESS) {
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   std::shared_ptr<CPUDeviceInfo> cpu = std::make_shared<CPUDeviceInfo>();
   cpu->SetEnableFP16(true);
   context->MutableDeviceInfo().push_back(cpu);
@@ -81,6 +85,7 @@ TEST_F(TestCxxApiContext, test_context_ascend_context_FAILED) {
   std::string option_9_ans = "1,2,3,4,5";
 
   auto context = std::make_shared<Context>();
+  context->SetThreadNum(kNumThreads);
   std::shared_ptr<AscendDeviceInfo> ascend310 = std::make_shared<AscendDeviceInfo>();
   ascend310->SetInputShape(option_1);
   ascend310->SetInsertOpConfigPath(option_2);