diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h index 3c8bbf79b1b..e9f71366630 100644 --- a/mindspore/lite/include/context.h +++ b/mindspore/lite/include/context.h @@ -37,13 +37,35 @@ typedef enum { DT_NPU /**< NPU device type, not supported yet */ } DeviceType; +/// \brief CpuDeviceInfo defined for CPU's configuration information. +typedef struct { + bool enable_float16_ = false; /**< prior enable float16 inference */ + CpuBindMode cpu_bind_mode_ = MID_CPU; +} CpuDeviceInfo; + +/// \brief GpuDeviceInfo defined for GPU's configuration information. +typedef struct { + bool enable_float16_ = false; /**< prior enable float16 inference */ +} GpuDeviceInfo; + +/// \brief DeviceInfo defined for backend's configuration information. +union DeviceInfo { + CpuDeviceInfo cpu_device_info_; + GpuDeviceInfo gpu_device_info_; +}; + +/// \brief DeviceContext defined for holding backend's configuration information. +struct DeviceContext { + DeviceType device_type_ = DT_CPU; + DeviceInfo device_info_; +}; + /// \brief Context defined for holding environment variables during runtime. struct Context { - bool enable_float16_ = false; /**< prior enable float16 inference */ - DeviceType device_type_ = DT_CPU; + std::string vendor_name_; int thread_num_ = 2; /**< thread number config for thread pool */ AllocatorPtr allocator = nullptr; - CpuBindMode cpu_bind_mode_ = MID_CPU; + DeviceContextVector device_list_ = {{DT_CPU, {false, MID_CPU}}}; }; } // namespace mindspore::lite #endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_ diff --git a/mindspore/lite/include/lite_utils.h b/mindspore/lite/include/lite_utils.h index 311647ba072..9a86b054ef0 100644 --- a/mindspore/lite/include/lite_utils.h +++ b/mindspore/lite/include/lite_utils.h @@ -27,7 +27,11 @@ namespace mindspore::lite { /// \note List public class and interface for reference. class Allocator; +/// \brief DeviceContext defined a device context. +struct DeviceContext; + using TensorPtrVector = std::vector; +using DeviceContextVector = std::vector; using Uint32Vector = std::vector; using String = std::string; using NodeType = schema::NodeType; diff --git a/mindspore/lite/src/inner_context.cc b/mindspore/lite/src/inner_context.cc index 085bf19bf63..7650f3aac6a 100644 --- a/mindspore/lite/src/inner_context.cc +++ b/mindspore/lite/src/inner_context.cc @@ -20,8 +20,13 @@ namespace mindspore::lite { int InnerContext::Init() { - if (this->thread_pool_ == nullptr) { - this->thread_pool_ = CreateLiteThreadPool(this->thread_num_, this->cpu_bind_mode_); + if (this->device_list_.empty()) { + MS_LOG(ERROR) << "Device list is empty."; + return RET_NOT_SUPPORT; + } + if (this->thread_pool_ == nullptr && this->device_list_[0].device_type_ == DT_CPU) { + this->thread_pool_ = + CreateLiteThreadPool(this->thread_num_, this->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_); if (this->thread_pool_ == nullptr) { MS_LOG(ERROR) << "Create ThreadPool failed"; return RET_NULL_PTR; diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 4706c1e3f7d..649f0e1d12e 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -315,14 +315,27 @@ int LiteSession::Init(Context *context) { return RET_ERROR; } - MS_ASSERT(nullptr != context); - if (context->device_type_ == DT_NPU) { + if (context == nullptr) { + MS_LOG(ERROR) << "context is nullptr"; + is_running_.store(false); + return RET_NULL_PTR; + } + + if (context->device_list_.empty()) { + MS_LOG(ERROR) << "Device list is empty."; + is_running_.store(false); + return RET_NOT_SUPPORT; + } + + auto &device_type = context->device_list_[0].device_type_; + + if (device_type == DT_NPU) { MS_LOG(ERROR) << "NPU is not supported."; is_running_.store(false); return RET_NOT_SUPPORT; } #ifndef SUPPORT_GPU - if (context->device_type_ == DT_GPU) { + if (device_type == DT_GPU) { MS_LOG(ERROR) << "GPU is not supported."; is_running_.store(false); return RET_NOT_SUPPORT; @@ -337,9 +350,10 @@ int LiteSession::Init(Context *context) { } this->context_->allocator = context->allocator; this->context_->thread_num_ = context->thread_num_; - this->context_->cpu_bind_mode_ = context->cpu_bind_mode_; - this->context_->device_type_ = context->device_type_; - this->context_->enable_float16_ = context->enable_float16_; + this->context_->device_list_.clear(); + for (auto &device_ctx : context->device_list_) { + this->context_->device_list_.push_back(device_ctx); + } auto ret = this->context_->Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init Context failed"; @@ -353,11 +367,12 @@ int LiteSession::Init(Context *context) { return ret; } #if SUPPORT_GPU - if (context_->device_type_ == DT_GPU) { + if (device_type == DT_GPU) { + auto gpu_device_info = this->context_->device_list_[0].device_info_.gpu_device_info_; auto opencl_runtime = ocl_runtime_wrap_.GetInstance(); - opencl_runtime->SetFp16Enable(context_->enable_float16_); + opencl_runtime->SetFp16Enable(gpu_device_info.enable_float16_); if (opencl_runtime->Init() != RET_OK) { - context_->device_type_ = DT_CPU; + device_type = DT_CPU; MS_LOG(WARNING) << "Init OpenCL runtime failed, change to CPU mode."; } else { MS_LOG(INFO) << "Init OpenCL runtime success."; @@ -375,9 +390,18 @@ int LiteSession::Init(Context *context) { } void LiteSession::BindThread(bool if_bind) { - if (this->context_->cpu_bind_mode_ != NO_BIND) { + if (this->context_->device_list_.empty()) { + MS_LOG(ERROR) << "Device list is empty."; + return; + } + auto &device_ctx = this->context_->device_list_[0]; + if (device_ctx.device_type_ != DT_CPU) { + MS_LOG(ERROR) << "Device is not CPU."; + return; + } + if (device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ != NO_BIND) { MS_ASSERT(this->context_->thread_pool_ != NULL); - BindThreads(this->context_->thread_pool_, if_bind, this->context_->cpu_bind_mode_); + BindThreads(this->context_->thread_pool_, if_bind, device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_); } } diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index d93cd5ced75..45e15d32867 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -192,10 +192,12 @@ void Scheduler::ConstructSubgraphs(std::vector *kernels) { std::vector subgraph_kernels; size_t sub_cnt{0}; + auto &device_ctx = context_->device_list_[0]; for (auto temp_kernels : sub_kernels_list) { std::vector output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels); for (auto tensor : output_tensor) { - if (context_->enable_float16_ && tensor->data_type() == kNumberTypeFloat16) { + if (device_ctx.device_type_ == DT_CPU && device_ctx.device_info_.cpu_device_info_.enable_float16_ && + tensor->data_type() == kNumberTypeFloat16) { tensor->set_data_type(kNumberTypeFloat32); } } @@ -246,8 +248,9 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector &in_tens MS_ASSERT(primitive != nullptr); TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast(primitive->Type())}; + auto &device_ctx = context_->device_list_[0]; #if SUPPORT_GPU - if (context_->device_type_ == DT_GPU) { + if (device_ctx.device_type_ == DT_GPU) { desc.arch = kernel::KERNEL_ARCH::kGPU; auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); if (kernel != nullptr) { @@ -262,7 +265,8 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector &in_tens #endif desc.arch = kernel::KERNEL_ARCH::kCPU; kernel::LiteKernel *kernel = nullptr; - if ((context_->enable_float16_ && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) { + if ((device_ctx.device_info_.cpu_device_info_.enable_float16_ && data_type == kNumberTypeFloat32) || + data_type == kNumberTypeFloat16) { // check if support fp16 kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type}; kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key); diff --git a/mindspore/lite/test/ut/src/infer_test.cc b/mindspore/lite/test/ut/src/infer_test.cc index 53ad644500e..e48a4c26649 100644 --- a/mindspore/lite/test/ut/src/infer_test.cc +++ b/mindspore/lite/test/ut/src/infer_test.cc @@ -106,8 +106,9 @@ TEST_F(InferTest, TestConvNode) { meta_graph.reset(); content = nullptr; auto context = new lite::InnerContext; - context->cpu_bind_mode_ = lite::NO_BIND; - context->device_type_ = lite::DT_CPU; + auto &device_list = context->device_list_; + lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}}; + device_list.push_back(device_ctx); context->thread_num_ = 4; ASSERT_EQ(lite::RET_OK, context->Init()); auto session = session::LiteSession::CreateSession(context); @@ -205,8 +206,9 @@ TEST_F(InferTest, TestAddNode) { meta_graph.reset(); content = nullptr; auto context = new lite::InnerContext; - context->cpu_bind_mode_ = lite::NO_BIND; - context->device_type_ = lite::DT_CPU; + auto &device_list = context->device_list_; + lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}}; + device_list.push_back(device_ctx); context->thread_num_ = 4; ASSERT_EQ(lite::RET_OK, context->Init()); auto session = session::LiteSession::CreateSession(context); @@ -295,8 +297,9 @@ TEST_F(InferTest, TestParallelExecutor) { meta_graph.reset(); content = nullptr; auto context = new lite::InnerContext; - context->cpu_bind_mode_ = lite::NO_BIND; - context->device_type_ = lite::DT_CPU; + auto &device_list = context->device_list_; + lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}}; + device_list.push_back(device_ctx); context->thread_num_ = 4; ASSERT_EQ(lite::RET_OK, context->Init()); auto session = new SessionWithParallelExecutor(); @@ -336,8 +339,7 @@ TEST_F(InferTest, TestModel) { ASSERT_NE(nullptr, model); delete[] buf[0]; auto context = new lite::InnerContext; - context->cpu_bind_mode_ = lite::NO_BIND; - context->device_type_ = lite::DT_CPU; + context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context->thread_num_ = 4; ASSERT_EQ(lite::RET_OK, context->Init()); auto session = session::LiteSession::CreateSession(context); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc index 8726b3c88c2..9cb40c05d44 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/bn_grad_fp32_test.cc @@ -68,7 +68,6 @@ TEST_F(TestBNGradFp32, BNGradFp32) { std::vector outputs = {&dx_tensor, &dscale_tensor, &dbias_tensor}; lite::InnerContext ctx; - ctx.device_type_ = lite::DT_CPU; ctx.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, ctx.Init()); @@ -171,7 +170,6 @@ TEST_F(TestBNGradFp32, BNTtrainFp32) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_FusedBatchNorm}; mindspore::lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc index 006a6ff81f1..b3c1d60df4a 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/convolution_grad_fp32_tests.cc @@ -108,7 +108,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -182,7 +181,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) { uint64_t time_avg = 0; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -255,7 +253,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -328,7 +325,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) { uint64_t time_avg = 0; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -401,7 +397,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -474,7 +469,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) { uint64_t time_avg = 0; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -542,7 +536,6 @@ TEST_F(TestConvolutionGradFp32, ConvGroupDilation) { uint64_t time_avg = 0; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -644,7 +637,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32Dilation2Group2Stride2FilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -745,7 +737,6 @@ TEST_F(TestConvolutionGradFp32, ConvGroup2Dilation2Stride2) { uint64_t time_avg = 0; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc index 5ca7276c6f4..96be0e6cf2d 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/deconvolution_grad_fp32_tests.cc @@ -90,7 +90,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32FilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -190,7 +189,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2FilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -290,7 +288,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group3FilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -390,7 +387,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group3Stride1FilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -490,7 +486,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group2Stride2FilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -590,7 +585,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group12Stride2FilterGrad) { std::vector outputs = {&dw_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc index 0c4b6367a96..7cb310e0aa3 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/network_test.cc @@ -357,8 +357,7 @@ TEST_F(NetworkTest, tuning_layer) { meta_graph.reset(); content = nullptr; lite::Context context; - context.device_type_ = lite::DT_CPU; - context.cpu_bind_mode_ = lite::NO_BIND; + context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context.thread_num_ = 1; auto session = session::TrainSession::CreateSession(&context); ASSERT_NE(nullptr, session); @@ -518,8 +517,7 @@ TEST_F(NetworkTest, efficient_net) { auto model = lite::TrainModel::Import(buf, net_size); delete[] buf; auto context = new lite::Context; - context->device_type_ = lite::DT_CPU; - context->cpu_bind_mode_ = lite::NO_BIND; + context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context->thread_num_ = 1; auto session = session::TrainSession::CreateSession(context); @@ -544,8 +542,7 @@ TEST_F(NetworkTest, lenetnet) { auto model = lite::TrainModel::Import(buf, net_size); delete[] buf; auto context = new lite::Context; - context->device_type_ = lite::DT_CPU; - context->cpu_bind_mode_ = lite::NO_BIND; + context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context->thread_num_ = 1; // check registration @@ -589,8 +586,7 @@ TEST_F(NetworkTest, retina_net) { auto model = lite::Model::Import(buf, net_size); delete[] buf; auto context = new lite::Context; - context->device_type_ = lite::DT_CPU; - context->cpu_bind_mode_ = lite::NO_BIND; + context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context->thread_num_ = 1; // auto session = session::TrainSession::CreateSession(context); @@ -640,8 +636,7 @@ TEST_F(NetworkTest, mobileface_net) { auto model = lite::Model::Import(buf, net_size); delete[] buf; auto context = new lite::Context; - context->device_type_ = lite::DT_CPU; - context->cpu_bind_mode_ = lite::NO_BIND; + context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND; context->thread_num_ = 1; // auto session = session::TrainSession::CreateSession(context); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc index b1e90b8f9d3..70256d1054a 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc @@ -141,7 +141,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { std::vector outputs = {&dx_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -206,7 +205,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) { std::vector outputs = {&dx_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -269,7 +267,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride2Fp32) { std::vector outputs = {&out_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -334,7 +331,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride3Fp32) { std::vector outputs = {&out_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -455,7 +451,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradBatchFp32) { std::vector maxpool_outputs = {&out_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -530,7 +525,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride2Fp32) { std::vector maxpool_outputs = {&out_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); @@ -605,7 +599,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride3Fp32) { std::vector maxpool_outputs = {&out_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc index 14c5777f58c..e1254e72021 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/softmax_crossentropy_fp32_tests.cc @@ -61,7 +61,6 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { std::vector outputs = {&loss_tensor, &grad_tensor}; lite::InnerContext context; - context.device_type_ = lite::DT_CPU; context.thread_num_ = 1; ASSERT_EQ(lite::RET_OK, context.Init()); diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc index 178774d9fae..fd6100077c2 100644 --- a/mindspore/lite/tools/benchmark/benchmark.cc +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -378,21 +378,28 @@ int Benchmark::RunBenchmark() { std::cerr << "New context failed while running " << model_name.c_str() << std::endl; return RET_ERROR; } + auto &device_ctx = context->device_list_[0]; if (flags_->device_ == "CPU") { - context->device_type_ = lite::DT_CPU; + device_ctx.device_type_ = lite::DT_CPU; } else if (flags_->device_ == "GPU") { - context->device_type_ = lite::DT_GPU; + device_ctx.device_type_ = lite::DT_GPU; } - if (flags_->cpu_bind_mode_ == 2) { - context->cpu_bind_mode_ = MID_CPU; - } else if (flags_->cpu_bind_mode_ == 1) { - context->cpu_bind_mode_ = HIGHER_CPU; - } else { - context->cpu_bind_mode_ = NO_BIND; + if (device_ctx.device_type_ == DT_CPU) { + if (flags_->cpu_bind_mode_ == MID_CPU) { + device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU; + } else if (flags_->cpu_bind_mode_ == HIGHER_CPU) { + device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; + } else { + device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; + } + device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_; + } + if (device_ctx.device_type_ == DT_GPU) { + device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_; } context->thread_num_ = flags_->num_threads_; - context->enable_float16_ = flags_->enable_fp16_; + session_ = session::LiteSession::CreateSession(context.get()); if (session_ == nullptr) { MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str(); diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 719f7094898..14373d5a389 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -1337,9 +1337,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { auto model = lite::Model::Import(content, size); Context ctx; - ctx.device_type_ = DT_CPU; ctx.thread_num_ = calibrator_->GetThreadNum(); - ctx.cpu_bind_mode_ = MID_CPU; fp32_session_ = dynamic_cast(session::LiteSession::CreateSession(&ctx)); if (fp32_session_ == nullptr) { @@ -1420,9 +1418,8 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { auto int8_model = lite::Model::Import(int8_content, size); Context int8_ctx; - int8_ctx.device_type_ = DT_CPU; int8_ctx.thread_num_ = calibrator_->GetThreadNum(); - int8_ctx.cpu_bind_mode_ = HIGHER_CPU; + int8_ctx.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; int8_session_ = dynamic_cast(session::LiteSession::CreateSession(&int8_ctx)); if (int8_session_ == nullptr) {