forked from mindspore-Ecosystem/mindspore
add device_ctx to context
This commit is contained in:
parent
57ebdb4545
commit
43633e28a6
|
@ -37,13 +37,35 @@ typedef enum {
|
|||
DT_NPU /**< NPU device type, not supported yet */
|
||||
} DeviceType;
|
||||
|
||||
/// \brief CpuDeviceInfo defined for CPU's configuration information.
|
||||
typedef struct {
|
||||
bool enable_float16_ = false; /**< prior enable float16 inference */
|
||||
CpuBindMode cpu_bind_mode_ = MID_CPU;
|
||||
} CpuDeviceInfo;
|
||||
|
||||
/// \brief GpuDeviceInfo defined for GPU's configuration information.
|
||||
typedef struct {
|
||||
bool enable_float16_ = false; /**< prior enable float16 inference */
|
||||
} GpuDeviceInfo;
|
||||
|
||||
/// \brief DeviceInfo defined for backend's configuration information.
|
||||
union DeviceInfo {
|
||||
CpuDeviceInfo cpu_device_info_;
|
||||
GpuDeviceInfo gpu_device_info_;
|
||||
};
|
||||
|
||||
/// \brief DeviceContext defined for holding backend's configuration information.
|
||||
struct DeviceContext {
|
||||
DeviceType device_type_ = DT_CPU;
|
||||
DeviceInfo device_info_;
|
||||
};
|
||||
|
||||
/// \brief Context defined for holding environment variables during runtime.
|
||||
struct Context {
|
||||
bool enable_float16_ = false; /**< prior enable float16 inference */
|
||||
DeviceType device_type_ = DT_CPU;
|
||||
std::string vendor_name_;
|
||||
int thread_num_ = 2; /**< thread number config for thread pool */
|
||||
AllocatorPtr allocator = nullptr;
|
||||
CpuBindMode cpu_bind_mode_ = MID_CPU;
|
||||
DeviceContextVector device_list_ = {{DT_CPU, {false, MID_CPU}}};
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_
|
||||
|
|
|
@ -27,7 +27,11 @@ namespace mindspore::lite {
|
|||
/// \note List public class and interface for reference.
|
||||
class Allocator;
|
||||
|
||||
/// \brief DeviceContext defined a device context.
|
||||
struct DeviceContext;
|
||||
|
||||
using TensorPtrVector = std::vector<mindspore::schema::Tensor *>;
|
||||
using DeviceContextVector = std::vector<DeviceContext>;
|
||||
using Uint32Vector = std::vector<uint32_t>;
|
||||
using String = std::string;
|
||||
using NodeType = schema::NodeType;
|
||||
|
|
|
@ -20,8 +20,13 @@
|
|||
|
||||
namespace mindspore::lite {
|
||||
int InnerContext::Init() {
|
||||
if (this->thread_pool_ == nullptr) {
|
||||
this->thread_pool_ = CreateLiteThreadPool(this->thread_num_, this->cpu_bind_mode_);
|
||||
if (this->device_list_.empty()) {
|
||||
MS_LOG(ERROR) << "Device list is empty.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
if (this->thread_pool_ == nullptr && this->device_list_[0].device_type_ == DT_CPU) {
|
||||
this->thread_pool_ =
|
||||
CreateLiteThreadPool(this->thread_num_, this->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_);
|
||||
if (this->thread_pool_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create ThreadPool failed";
|
||||
return RET_NULL_PTR;
|
||||
|
|
|
@ -315,14 +315,27 @@ int LiteSession::Init(Context *context) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
MS_ASSERT(nullptr != context);
|
||||
if (context->device_type_ == DT_NPU) {
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "context is nullptr";
|
||||
is_running_.store(false);
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (context->device_list_.empty()) {
|
||||
MS_LOG(ERROR) << "Device list is empty.";
|
||||
is_running_.store(false);
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
auto &device_type = context->device_list_[0].device_type_;
|
||||
|
||||
if (device_type == DT_NPU) {
|
||||
MS_LOG(ERROR) << "NPU is not supported.";
|
||||
is_running_.store(false);
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
#ifndef SUPPORT_GPU
|
||||
if (context->device_type_ == DT_GPU) {
|
||||
if (device_type == DT_GPU) {
|
||||
MS_LOG(ERROR) << "GPU is not supported.";
|
||||
is_running_.store(false);
|
||||
return RET_NOT_SUPPORT;
|
||||
|
@ -337,9 +350,10 @@ int LiteSession::Init(Context *context) {
|
|||
}
|
||||
this->context_->allocator = context->allocator;
|
||||
this->context_->thread_num_ = context->thread_num_;
|
||||
this->context_->cpu_bind_mode_ = context->cpu_bind_mode_;
|
||||
this->context_->device_type_ = context->device_type_;
|
||||
this->context_->enable_float16_ = context->enable_float16_;
|
||||
this->context_->device_list_.clear();
|
||||
for (auto &device_ctx : context->device_list_) {
|
||||
this->context_->device_list_.push_back(device_ctx);
|
||||
}
|
||||
auto ret = this->context_->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init Context failed";
|
||||
|
@ -353,11 +367,12 @@ int LiteSession::Init(Context *context) {
|
|||
return ret;
|
||||
}
|
||||
#if SUPPORT_GPU
|
||||
if (context_->device_type_ == DT_GPU) {
|
||||
if (device_type == DT_GPU) {
|
||||
auto gpu_device_info = this->context_->device_list_[0].device_info_.gpu_device_info_;
|
||||
auto opencl_runtime = ocl_runtime_wrap_.GetInstance();
|
||||
opencl_runtime->SetFp16Enable(context_->enable_float16_);
|
||||
opencl_runtime->SetFp16Enable(gpu_device_info.enable_float16_);
|
||||
if (opencl_runtime->Init() != RET_OK) {
|
||||
context_->device_type_ = DT_CPU;
|
||||
device_type = DT_CPU;
|
||||
MS_LOG(WARNING) << "Init OpenCL runtime failed, change to CPU mode.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "Init OpenCL runtime success.";
|
||||
|
@ -375,9 +390,18 @@ int LiteSession::Init(Context *context) {
|
|||
}
|
||||
|
||||
void LiteSession::BindThread(bool if_bind) {
|
||||
if (this->context_->cpu_bind_mode_ != NO_BIND) {
|
||||
if (this->context_->device_list_.empty()) {
|
||||
MS_LOG(ERROR) << "Device list is empty.";
|
||||
return;
|
||||
}
|
||||
auto &device_ctx = this->context_->device_list_[0];
|
||||
if (device_ctx.device_type_ != DT_CPU) {
|
||||
MS_LOG(ERROR) << "Device is not CPU.";
|
||||
return;
|
||||
}
|
||||
if (device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ != NO_BIND) {
|
||||
MS_ASSERT(this->context_->thread_pool_ != NULL);
|
||||
BindThreads(this->context_->thread_pool_, if_bind, this->context_->cpu_bind_mode_);
|
||||
BindThreads(this->context_->thread_pool_, if_bind, device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -192,10 +192,12 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) {
|
|||
|
||||
std::vector<kernel::LiteKernel *> subgraph_kernels;
|
||||
size_t sub_cnt{0};
|
||||
auto &device_ctx = context_->device_list_[0];
|
||||
for (auto temp_kernels : sub_kernels_list) {
|
||||
std::vector<Tensor *> output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels);
|
||||
for (auto tensor : output_tensor) {
|
||||
if (context_->enable_float16_ && tensor->data_type() == kNumberTypeFloat16) {
|
||||
if (device_ctx.device_type_ == DT_CPU && device_ctx.device_info_.cpu_device_info_.enable_float16_ &&
|
||||
tensor->data_type() == kNumberTypeFloat16) {
|
||||
tensor->set_data_type(kNumberTypeFloat32);
|
||||
}
|
||||
}
|
||||
|
@ -246,8 +248,9 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tens
|
|||
MS_ASSERT(primitive != nullptr);
|
||||
TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
|
||||
auto &device_ctx = context_->device_list_[0];
|
||||
#if SUPPORT_GPU
|
||||
if (context_->device_type_ == DT_GPU) {
|
||||
if (device_ctx.device_type_ == DT_GPU) {
|
||||
desc.arch = kernel::KERNEL_ARCH::kGPU;
|
||||
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
|
||||
if (kernel != nullptr) {
|
||||
|
@ -262,7 +265,8 @@ kernel::LiteKernel *Scheduler::ScheduleNode(const std::vector<Tensor *> &in_tens
|
|||
#endif
|
||||
desc.arch = kernel::KERNEL_ARCH::kCPU;
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
if ((context_->enable_float16_ && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16) {
|
||||
if ((device_ctx.device_info_.cpu_device_info_.enable_float16_ && data_type == kNumberTypeFloat32) ||
|
||||
data_type == kNumberTypeFloat16) {
|
||||
// check if support fp16
|
||||
kernel::KernelKey key{desc.arch, kNumberTypeFloat16, desc.type};
|
||||
kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, key);
|
||||
|
|
|
@ -106,8 +106,9 @@ TEST_F(InferTest, TestConvNode) {
|
|||
meta_graph.reset();
|
||||
content = nullptr;
|
||||
auto context = new lite::InnerContext;
|
||||
context->cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
auto &device_list = context->device_list_;
|
||||
lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}};
|
||||
device_list.push_back(device_ctx);
|
||||
context->thread_num_ = 4;
|
||||
ASSERT_EQ(lite::RET_OK, context->Init());
|
||||
auto session = session::LiteSession::CreateSession(context);
|
||||
|
@ -205,8 +206,9 @@ TEST_F(InferTest, TestAddNode) {
|
|||
meta_graph.reset();
|
||||
content = nullptr;
|
||||
auto context = new lite::InnerContext;
|
||||
context->cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
auto &device_list = context->device_list_;
|
||||
lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}};
|
||||
device_list.push_back(device_ctx);
|
||||
context->thread_num_ = 4;
|
||||
ASSERT_EQ(lite::RET_OK, context->Init());
|
||||
auto session = session::LiteSession::CreateSession(context);
|
||||
|
@ -295,8 +297,9 @@ TEST_F(InferTest, TestParallelExecutor) {
|
|||
meta_graph.reset();
|
||||
content = nullptr;
|
||||
auto context = new lite::InnerContext;
|
||||
context->cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
auto &device_list = context->device_list_;
|
||||
lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}};
|
||||
device_list.push_back(device_ctx);
|
||||
context->thread_num_ = 4;
|
||||
ASSERT_EQ(lite::RET_OK, context->Init());
|
||||
auto session = new SessionWithParallelExecutor();
|
||||
|
@ -336,8 +339,7 @@ TEST_F(InferTest, TestModel) {
|
|||
ASSERT_NE(nullptr, model);
|
||||
delete[] buf[0];
|
||||
auto context = new lite::InnerContext;
|
||||
context->cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->thread_num_ = 4;
|
||||
ASSERT_EQ(lite::RET_OK, context->Init());
|
||||
auto session = session::LiteSession::CreateSession(context);
|
||||
|
|
|
@ -68,7 +68,6 @@ TEST_F(TestBNGradFp32, BNGradFp32) {
|
|||
std::vector<lite::Tensor *> outputs = {&dx_tensor, &dscale_tensor, &dbias_tensor};
|
||||
|
||||
lite::InnerContext ctx;
|
||||
ctx.device_type_ = lite::DT_CPU;
|
||||
ctx.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, ctx.Init());
|
||||
|
||||
|
@ -171,7 +170,6 @@ TEST_F(TestBNGradFp32, BNTtrainFp32) {
|
|||
kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_FusedBatchNorm};
|
||||
|
||||
mindspore::lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
|
|
@ -108,7 +108,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -182,7 +181,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) {
|
|||
uint64_t time_avg = 0;
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -255,7 +253,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -328,7 +325,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) {
|
|||
uint64_t time_avg = 0;
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -401,7 +397,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -474,7 +469,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationInputGrad) {
|
|||
uint64_t time_avg = 0;
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -542,7 +536,6 @@ TEST_F(TestConvolutionGradFp32, ConvGroupDilation) {
|
|||
uint64_t time_avg = 0;
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -644,7 +637,6 @@ TEST_F(TestConvolutionGradFp32, ConvFp32Dilation2Group2Stride2FilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -745,7 +737,6 @@ TEST_F(TestConvolutionGradFp32, ConvGroup2Dilation2Stride2) {
|
|||
uint64_t time_avg = 0;
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
|
|
@ -90,7 +90,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32FilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -190,7 +189,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2FilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -290,7 +288,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group3FilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -390,7 +387,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group3Stride1FilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -490,7 +486,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group2Stride2FilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -590,7 +585,6 @@ TEST_F(TestDeConvolutionGradFp32, DeConvFp32Dilation2Group12Stride2FilterGrad) {
|
|||
std::vector<lite::Tensor *> outputs = {&dw_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
|
|
@ -357,8 +357,7 @@ TEST_F(NetworkTest, tuning_layer) {
|
|||
meta_graph.reset();
|
||||
content = nullptr;
|
||||
lite::Context context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.cpu_bind_mode_ = lite::NO_BIND;
|
||||
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
|
||||
context.thread_num_ = 1;
|
||||
auto session = session::TrainSession::CreateSession(&context);
|
||||
ASSERT_NE(nullptr, session);
|
||||
|
@ -518,8 +517,7 @@ TEST_F(NetworkTest, efficient_net) {
|
|||
auto model = lite::TrainModel::Import(buf, net_size);
|
||||
delete[] buf;
|
||||
auto context = new lite::Context;
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
context->cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->thread_num_ = 1;
|
||||
|
||||
auto session = session::TrainSession::CreateSession(context);
|
||||
|
@ -544,8 +542,7 @@ TEST_F(NetworkTest, lenetnet) {
|
|||
auto model = lite::TrainModel::Import(buf, net_size);
|
||||
delete[] buf;
|
||||
auto context = new lite::Context;
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
context->cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->thread_num_ = 1;
|
||||
|
||||
// check registration
|
||||
|
@ -589,8 +586,7 @@ TEST_F(NetworkTest, retina_net) {
|
|||
auto model = lite::Model::Import(buf, net_size);
|
||||
delete[] buf;
|
||||
auto context = new lite::Context;
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
context->cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->thread_num_ = 1;
|
||||
|
||||
// auto session = session::TrainSession::CreateSession(context);
|
||||
|
@ -640,8 +636,7 @@ TEST_F(NetworkTest, mobileface_net) {
|
|||
auto model = lite::Model::Import(buf, net_size);
|
||||
delete[] buf;
|
||||
auto context = new lite::Context;
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
context->cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = lite::NO_BIND;
|
||||
context->thread_num_ = 1;
|
||||
|
||||
// auto session = session::TrainSession::CreateSession(context);
|
||||
|
|
|
@ -141,7 +141,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) {
|
|||
std::vector<lite::Tensor *> outputs = {&dx_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -206,7 +205,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) {
|
|||
std::vector<lite::Tensor *> outputs = {&dx_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -269,7 +267,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride2Fp32) {
|
|||
std::vector<lite::Tensor *> outputs = {&out_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -334,7 +331,6 @@ TEST_F(TestPoolingGradFp32, AvgPoolGradStride3Fp32) {
|
|||
std::vector<lite::Tensor *> outputs = {&out_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -455,7 +451,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradBatchFp32) {
|
|||
std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -530,7 +525,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride2Fp32) {
|
|||
std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
@ -605,7 +599,6 @@ TEST_F(TestPoolingGradFp32, MaxPoolGradStride3Fp32) {
|
|||
std::vector<lite::Tensor *> maxpool_outputs = {&out_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
|
|
@ -61,7 +61,6 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) {
|
|||
std::vector<lite::Tensor *> outputs = {&loss_tensor, &grad_tensor};
|
||||
|
||||
lite::InnerContext context;
|
||||
context.device_type_ = lite::DT_CPU;
|
||||
context.thread_num_ = 1;
|
||||
ASSERT_EQ(lite::RET_OK, context.Init());
|
||||
|
||||
|
|
|
@ -378,21 +378,28 @@ int Benchmark::RunBenchmark() {
|
|||
std::cerr << "New context failed while running " << model_name.c_str() << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto &device_ctx = context->device_list_[0];
|
||||
if (flags_->device_ == "CPU") {
|
||||
context->device_type_ = lite::DT_CPU;
|
||||
device_ctx.device_type_ = lite::DT_CPU;
|
||||
} else if (flags_->device_ == "GPU") {
|
||||
context->device_type_ = lite::DT_GPU;
|
||||
device_ctx.device_type_ = lite::DT_GPU;
|
||||
}
|
||||
|
||||
if (flags_->cpu_bind_mode_ == 2) {
|
||||
context->cpu_bind_mode_ = MID_CPU;
|
||||
} else if (flags_->cpu_bind_mode_ == 1) {
|
||||
context->cpu_bind_mode_ = HIGHER_CPU;
|
||||
} else {
|
||||
context->cpu_bind_mode_ = NO_BIND;
|
||||
if (device_ctx.device_type_ == DT_CPU) {
|
||||
if (flags_->cpu_bind_mode_ == MID_CPU) {
|
||||
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = MID_CPU;
|
||||
} else if (flags_->cpu_bind_mode_ == HIGHER_CPU) {
|
||||
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
|
||||
} else {
|
||||
device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
|
||||
}
|
||||
device_ctx.device_info_.cpu_device_info_.enable_float16_ = flags_->enable_fp16_;
|
||||
}
|
||||
if (device_ctx.device_type_ == DT_GPU) {
|
||||
device_ctx.device_info_.gpu_device_info_.enable_float16_ = flags_->enable_fp16_;
|
||||
}
|
||||
context->thread_num_ = flags_->num_threads_;
|
||||
context->enable_float16_ = flags_->enable_fp16_;
|
||||
|
||||
session_ = session::LiteSession::CreateSession(context.get());
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str();
|
||||
|
|
|
@ -1337,9 +1337,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
auto model = lite::Model::Import(content, size);
|
||||
|
||||
Context ctx;
|
||||
ctx.device_type_ = DT_CPU;
|
||||
ctx.thread_num_ = calibrator_->GetThreadNum();
|
||||
ctx.cpu_bind_mode_ = MID_CPU;
|
||||
|
||||
fp32_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx));
|
||||
if (fp32_session_ == nullptr) {
|
||||
|
@ -1420,9 +1418,8 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
auto int8_model = lite::Model::Import(int8_content, size);
|
||||
|
||||
Context int8_ctx;
|
||||
int8_ctx.device_type_ = DT_CPU;
|
||||
int8_ctx.thread_num_ = calibrator_->GetThreadNum();
|
||||
int8_ctx.cpu_bind_mode_ = HIGHER_CPU;
|
||||
int8_ctx.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
|
||||
|
||||
int8_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&int8_ctx));
|
||||
if (int8_session_ == nullptr) {
|
||||
|
|
Loading…
Reference in New Issue