fix device priority

sunsuodong 2021-12-15 18:38:41 -08:00
parent 19dd2e21b4
commit 62230a0e78
19 changed files with 465 additions and 328 deletions

View File

@@ -32,8 +32,8 @@ class ModelC {
}
}
Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context);
Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context);
int Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context);
int Build(const std::string &model_path, ModelType model_type, const ContextC *model_context);
Status Resize(const std::vector<MSTensor::Impl *> &inputs, const std::vector<std::vector<int64_t>> &shapes);
Status Predict(const MSTensorHandle *inputs, size_t input_num, MSTensorHandle **outputs, size_t *output_num,
@@ -45,45 +45,41 @@ class ModelC {
MSTensor::Impl *GetOutputByTensorName(const std::string &name);
private:
std::shared_ptr<session::LiteSession> session_ = nullptr;
std::shared_ptr<const ContextC> context_ = nullptr;
std::unique_ptr<lite::LiteSession> session_ = nullptr;
std::unique_ptr<const ContextC> context_ = nullptr;
std::map<mindspore::tensor::MSTensor *, MSTensor::Impl *> tensor_map_;
std::vector<MSTensor::Impl *> inputs_;
std::vector<MSTensor::Impl *> outputs_;
int CreateLiteSession(const ContextC *context);
Status RunGraph(const MSKernelCallBackC &before, const MSKernelCallBackC &after);
void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MSTensor *> tensors);
MSTensor::Impl *TensorToTensorImpl(mindspore::tensor::MSTensor *tensor);
};
Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) {
context_.reset(model_context);
lite::Context lite_context;
auto status = A2L_ConvertContext(model_context, &lite_context);
if (status != kSuccess) {
return status;
int ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) {
int ret = CreateLiteSession(model_context);
if (ret != lite::RET_OK) {
return ret;
}
session_ = std::shared_ptr<session::LiteSession>(
session::LiteSession::CreateSession(static_cast<const char *>(model_data), data_size, &lite_context));
if (session_ == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
}
return kSuccess;
return session_->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), data_size);
}
Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) {
context_.reset(model_context);
lite::Context lite_context;
auto status = A2L_ConvertContext(model_context, &lite_context);
if (status != kSuccess) {
return status;
int ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) {
int ret = CreateLiteSession(model_context);
if (ret != lite::RET_OK) {
return ret;
}
session_ = std::shared_ptr<session::LiteSession>(lite::LiteSession::CreateSession(model_path, &lite_context));
return session_->LoadModelAndCompileByPath(model_path);
}
int ModelC::CreateLiteSession(const ContextC *context) {
context_.reset(context);
session_.reset(new (std::nothrow) lite::LiteSession());
if (session_ == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteError;
MS_LOG(ERROR) << "create session failed";
return kLiteMemoryFailed;
}
return kSuccess;
return session_->Init(ContextUtils::Convert(context_.get()));
}
Status ModelC::Resize(const std::vector<MSTensor::Impl *> &inputs, const std::vector<std::vector<int64_t>> &shapes) {
@@ -332,7 +328,7 @@ MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t data_s
mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
auto impl = static_cast<mindspore::ModelC *>(model);
auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context);
return static_cast<MSStatus>(ret.StatusCode());
return static_cast<MSStatus>(ret);
}
MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSModelType model_type,
@@ -348,7 +344,7 @@ MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSMod
mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
auto impl = static_cast<mindspore::ModelC *>(model);
auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context);
return static_cast<MSStatus>(ret.StatusCode());
return static_cast<MSStatus>(ret);
}
MSStatus MSModelResize(MSModelHandle model, const MSTensorHandleArray inputs, MSShapeInfo *shape_infos,

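With this refactor, ModelC::Build returns the lite int status code directly, and MSModelBuild/MSModelBuildFromFile cast that int straight to MSStatus instead of unwrapping a Status object. For illustration, a minimal hedged sketch of the resulting C-API call path (assuming the public c_api headers; error handling trimmed):

#include "include/c_api/context_c.h"
#include "include/c_api/model_c.h"

MSStatus BuildExample(void) {
  MSContextHandle context = MSContextCreate();
  MSDeviceInfoHandle cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU);
  MSContextAddDeviceInfo(context, cpu_device_info);
  MSModelHandle model = MSModelCreate();
  // The MSStatus below is now a direct cast of ModelC::Build's int result.
  MSStatus ret = MSModelBuildFromFile(model, "model.ms", kMSModelTypeMindIR, context);
  if (ret != kMSStatusSuccess) {
    MSModelDestroy(&model);
  }
  return ret;
}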
View File

@@ -115,5 +115,36 @@ std::set<std::string> ProvidersFromMSContext(const mindspore::Context *context)
}
return providers;
}
bool DeviceTypePriority(const lite::Context *context, int device_type1, int device_type2) {
/* returns true if device_type1 has higher priority than device_type2,
* i.e. device_type1 appears earlier in the context's device list */
if (context == nullptr) {
return false;
}
DeviceContextVector device_infos = context->device_list_;
for (DeviceContext device_info : device_infos) {
if (device_info.device_type_ == device_type1) {
return true;
}
if (device_info.device_type_ == device_type2) {
return false;
}
}
return false;
}
DeviceType KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch) {
switch (kernel_arch) {
case kernel::KERNEL_ARCH::kCPU:
return DT_CPU;
case kernel::KERNEL_ARCH::kGPU:
return DT_GPU;
case kernel::KERNEL_ARCH::kNPU:
return DT_NPU;
default:
return DT_CPU;
}
}
} // namespace lite
} // namespace mindspore
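DeviceTypePriority makes list position the single source of truth for device priority: whichever type appears first in the user's device list wins. A standalone sketch of the rule, using simplified stand-in types rather than the real lite::Context (not part of this commit):

#include <iostream>
#include <vector>

enum DeviceType { DT_CPU, DT_GPU, DT_NPU };

// Earlier position in the user-supplied device list means higher priority.
bool DeviceTypePriorityDemo(const std::vector<DeviceType> &device_list, DeviceType dt1, DeviceType dt2) {
  for (DeviceType device : device_list) {
    if (device == dt1) return true;   // dt1 found first: higher priority
    if (device == dt2) return false;  // dt2 found first: lower priority
  }
  return false;  // neither type present
}

int main() {
  std::vector<DeviceType> devices = {DT_GPU, DT_CPU};  // user listed GPU first
  std::cout << std::boolalpha
            << DeviceTypePriorityDemo(devices, DT_GPU, DT_CPU) << "\n"   // true
            << DeviceTypePriorityDemo(devices, DT_NPU, DT_CPU) << "\n";  // false
  return 0;
}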

View File

@@ -21,11 +21,14 @@
#include <string>
#include "include/context.h"
#include "include/api/context.h"
#include "src/lite_kernel.h"
namespace mindspore {
namespace lite {
mindspore::Context *MSContextFromContext(const lite::Context *context);
std::set<std::string> ProvidersFromMSContext(const mindspore::Context *context);
bool DeviceTypePriority(const lite::Context *context, int device_type1, int device_type2);
DeviceType KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_COMMON_CONTEXT_UTIL_H_

View File

@@ -14,149 +14,127 @@
* limitations under the License.
*/
#include "src/cxx_api/converters.h"
#include <cstddef>
#include <string>
#include <vector>
#include <memory>
#include "include/context.h"
#include "include/api/context.h"
#include "src/runtime/inner_allocator.h"
#include "src/common/log_adapter.h"
namespace mindspore {
constexpr static int kMaxNumOfDevices = 2;
constexpr static int kMaxNumOfDevices = 3;
Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) {
if ((a_context == nullptr) || (l_context == nullptr)) {
MS_LOG(ERROR) << "Invalid context pointers.";
return kLiteNullptr;
}
void ContextUtils::SetContextAttr(int32_t thread_num, bool enable_parallel,
const std::vector<int32_t> &affinity_core_list,
const std::shared_ptr<Delegate> &delegate, lite::InnerContext *inner_context) {
inner_context->thread_num_ = thread_num;
inner_context->enable_parallel_ = enable_parallel;
inner_context->affinity_core_list_ = affinity_core_list;
inner_context->delegate = delegate;
}
auto device_list = a_context->MutableDeviceInfo();
if (device_list.size() == 0) {
MS_LOG(ERROR) << "Invalid device list.";
Status ContextUtils::AddCpuDevice(const std::shared_ptr<Allocator> &allocator, int affinity_mode, bool enable_fp16,
const std::string &provider, const std::string &provider_device,
lite::InnerContext *inner_context) {
inner_context->allocator = allocator;
if (!IsAffinityModeValid(affinity_mode)) {
MS_LOG(ERROR) << "Invalid affinity mode, only supports 0:no affinities, 1:big cores first, 2:little cores first.";
return kLiteInputParamInvalid;
}
if (device_list.size() > kMaxNumOfDevices) {
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
return kLiteInputParamInvalid;
}
l_context->thread_num_ = a_context->GetThreadNum();
l_context->enable_parallel_ = a_context->GetEnableParallel();
l_context->affinity_core_list_ = a_context->GetThreadAffinityCoreList();
l_context->device_list_.clear();
if (device_list[0]->GetDeviceType() != kCPU) {
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
return kLiteInputParamInvalid;
}
auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
l_context->allocator = cpu_context->GetAllocator();
if (l_context->allocator == nullptr) {
l_context->allocator = Allocator::Create();
if (l_context->allocator == nullptr) {
MS_LOG(ERROR) << "Create Allocator failed.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "Set new allocator.";
cpu_context->SetAllocator(l_context->allocator);
}
if (!IsAffinityModeValid(a_context->GetThreadAffinityMode())) {
MS_LOG(ERROR)
<< "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first.";
return kLiteInputParamInvalid;
}
lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->GetThreadAffinityMode());
lite::DeviceInfo cpu_info = {0};
cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
if (device_list.size() == kMaxNumOfDevices) {
lite::DeviceInfo device_info = {0};
if (device_list[1]->GetDeviceType() == kGPU) {
auto gpu_context = device_list[1]->Cast<GPUDeviceInfo>();
device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
} else if (device_list[1]->GetDeviceType() == kKirinNPU) {
auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
device_info.npu_device_info_ = {npu_context->GetFrequency()};
l_context->device_list_.push_back({lite::DT_NPU, device_info});
} else {
MS_LOG(ERROR) << "Invalid device.";
return kLiteInputParamInvalid;
}
}
l_context->delegate = a_context->GetDelegate();
lite::DeviceInfo device_info = {0};
device_info.cpu_device_info_ = {enable_fp16, static_cast<lite::CpuBindMode>(affinity_mode)};
inner_context->device_list_.push_back({lite::DT_CPU, device_info, provider, provider_device, allocator});
return kSuccess;
}
Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context) {
if ((a_context == nullptr) || (l_context == nullptr)) {
MS_LOG(ERROR) << "Invalid context pointers.";
return kLiteNullptr;
}
auto device_list = a_context->device_info_list;
if (device_list.size() == 0) {
MS_LOG(ERROR) << "Invalid device list.";
return kLiteInputParamInvalid;
}
if (device_list.size() > kMaxNumOfDevices) {
MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
return kLiteInputParamInvalid;
}
l_context->thread_num_ = a_context->thread_num;
l_context->enable_parallel_ = a_context->enable_parallel;
l_context->affinity_core_list_ = a_context->affinity_core_list;
l_context->device_list_.clear();
if (device_list[0]->device_type != kMSDeviceTypeCPU) {
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
return kLiteInputParamInvalid;
}
auto cpu_context = device_list[0];
l_context->allocator = cpu_context->allocator;
if (l_context->allocator == nullptr) {
l_context->allocator = Allocator::Create();
if (l_context->allocator == nullptr) {
MS_LOG(ERROR) << "Create Allocator failed.";
return kLiteNullptr;
}
MS_LOG(DEBUG) << "Set new allocator.";
cpu_context->allocator = l_context->allocator;
}
if (!IsAffinityModeValid(a_context->affinity_mode)) {
MS_LOG(ERROR)
<< "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first.";
return kLiteInputParamInvalid;
}
lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode);
lite::DeviceInfo cpu_info = {0};
cpu_info.cpu_device_info_ = {cpu_context->enable_fp16, mode};
l_context->device_list_.push_back(
{lite::DT_CPU, cpu_info, cpu_context->provider, cpu_context->provider_device, cpu_context->allocator});
if (device_list.size() == kMaxNumOfDevices) {
lite::DeviceInfo device_info = {0};
if (device_list[1]->device_type == kMSDeviceTypeGPU) {
auto gpu_context = device_list[1];
device_info.gpu_device_info_ = {gpu_context->enable_fp16};
l_context->device_list_.push_back(
{lite::DT_GPU, device_info, gpu_context->provider, gpu_context->provider_device, gpu_context->allocator});
} else if (device_list[1]->device_type == kMSDeviceTypeKirinNPU) {
auto npu_context = device_list[1];
device_info.npu_device_info_ = {npu_context->frequency};
l_context->device_list_.push_back({lite::DT_NPU, device_info});
} else {
MS_LOG(ERROR) << "Invalid device.";
return kLiteInputParamInvalid;
}
}
l_context->delegate = a_context->delegate;
Status ContextUtils::AddGpuDevice(bool enable_fp16, const std::string &provider, const std::string &provider_device,
const std::shared_ptr<Allocator> &allocator, lite::InnerContext *inner_context) {
lite::DeviceInfo device_info = {0};
device_info.gpu_device_info_ = {enable_fp16};
inner_context->device_list_.push_back({lite::DT_GPU, device_info, provider, provider_device, allocator});
return kSuccess;
}
Status ContextUtils::AddNpuDevice(int frequency, lite::InnerContext *inner_context) {
lite::DeviceInfo device_info = {0};
device_info.npu_device_info_ = {frequency};
inner_context->device_list_.push_back({lite::DT_NPU, device_info});
return kSuccess;
}
lite::InnerContext *ContextUtils::Convert(Context *context) {
auto inner_context = std::make_unique<lite::InnerContext>();
if ((context == nullptr) || (inner_context == nullptr)) {
MS_LOG(ERROR) << "Invalid context pointers.";
return nullptr;
}
auto device_list = context->MutableDeviceInfo();
if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) {
MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices;
return nullptr;
}
SetContextAttr(context->GetThreadNum(), context->GetEnableParallel(), context->GetThreadAffinityCoreList(),
context->GetDelegate(), inner_context.get());
inner_context->device_list_.clear();
Status ret = kLiteError;
for (auto &device : device_list) {
if (device == nullptr) {
return nullptr;
}
if (device->GetDeviceType() == kCPU) {
auto cpu_context = device->Cast<CPUDeviceInfo>();
if (cpu_context->GetAllocator() == nullptr) {
cpu_context->SetAllocator(Allocator::Create());
}
ret = AddCpuDevice(cpu_context->GetAllocator(), context->GetThreadAffinityMode(), cpu_context->GetEnableFP16(),
cpu_context->GetProvider(), cpu_context->GetProviderDevice(), inner_context.get());
} else if (device->GetDeviceType() == kGPU) {
auto gpu_context = device->Cast<GPUDeviceInfo>();
ret = AddGpuDevice(gpu_context->GetEnableFP16(), gpu_context->GetProvider(), gpu_context->GetProviderDevice(),
gpu_context->GetAllocator(), inner_context.get());
} else if (device->GetDeviceType() == kKirinNPU) {
auto npu_context = device->Cast<KirinNPUDeviceInfo>();
ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get());
}
if (ret != kSuccess) {
MS_LOG(ERROR) << "Add device failed!";
return nullptr;
}
}
return inner_context.release();
}
lite::InnerContext *ContextUtils::Convert(const ContextC *context_c) {
auto inner_context = std::make_unique<lite::InnerContext>();
if ((context_c == nullptr) || (inner_context == nullptr)) {
MS_LOG(ERROR) << "Invalid context pointers.";
return nullptr;
}
auto device_list = context_c->device_info_list;
if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) {
MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices;
return nullptr;
}
SetContextAttr(context_c->thread_num, context_c->enable_parallel, context_c->affinity_core_list, context_c->delegate,
inner_context.get());
inner_context->device_list_.clear();
Status ret = kLiteError;
for (auto &device_info_c : device_list) {
if (device_info_c == nullptr) {
return nullptr;
}
if (device_info_c->device_type == kMSDeviceTypeCPU) {
if (device_info_c->allocator == nullptr) {
device_info_c->allocator = Allocator::Create();
}
ret = AddCpuDevice(device_info_c->allocator, context_c->affinity_mode, device_info_c->enable_fp16,
device_info_c->provider, device_info_c->provider_device, inner_context.get());
} else if (device_info_c->device_type == kMSDeviceTypeGPU) {
ret = AddGpuDevice(device_info_c->enable_fp16, device_info_c->provider, device_info_c->provider_device,
device_info_c->allocator, inner_context.get());
} else if (device_info_c->device_type == kMSDeviceTypeKirinNPU) {
ret = AddNpuDevice(device_info_c->frequency, inner_context.get());
}
if (ret != kSuccess) {
MS_LOG(ERROR) << "Add device failed!";
return nullptr;
}
}
return inner_context.release();
}
} // namespace mindspore
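ContextUtils::Convert now walks the device list in the user's order instead of forcing CPU into the first slot, so priority is expressed purely by list position. A hedged sketch of how a caller expresses "GPU before CPU" with the public C++ API (illustrative only):

#include <memory>
#include "include/api/context.h"

std::shared_ptr<mindspore::Context> MakeGpuFirstContext() {
  auto context = std::make_shared<mindspore::Context>();
  auto &device_list = context->MutableDeviceInfo();
  // GPU first: Convert preserves this order, so GPU gets priority over CPU.
  device_list.push_back(std::make_shared<mindspore::GPUDeviceInfo>());
  device_list.push_back(std::make_shared<mindspore::CPUDeviceInfo>());
  return context;
}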

View File

@@ -13,26 +13,38 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_
#define MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_
#include <limits.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/context.h"
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/lite_types.h"
#include "src/cxx_api/context.h"
#include "include/api/cfg.h"
#include "include/train/train_cfg.h"
#include "src/inner_context.h"
#include "src/c_api/context_c.h"
namespace mindspore {
class ContextUtils {
public:
static lite::InnerContext *Convert(Context *context);
static lite::InnerContext *Convert(const ContextC *context_c);
namespace lite {
struct Context;
class TrainCfg;
} // namespace lite
class Context;
class TrainCfg;
private:
static void SetContextAttr(int32_t thread_num, bool enable_parallel, const std::vector<int32_t> &affinity_core_list,
const std::shared_ptr<Delegate> &delegate, lite::InnerContext *inner_context);
static Status AddCpuDevice(const std::shared_ptr<Allocator> &allocator, int affinity_mode, bool enable_fp16,
const std::string &provider, const std::string &provider_device,
lite::InnerContext *inner_context);
static Status AddGpuDevice(bool enable_fp16, const std::string &provider, const std::string &provider_device,
const std::shared_ptr<Allocator> &allocator, lite::InnerContext *inner_context);
static Status AddNpuDevice(int frequency, lite::InnerContext *inner_context);
static bool IsAffinityModeValid(int affinity_mode) {
return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
}
};
inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
if (qt == kNoQuant) {
@@ -44,25 +56,6 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
return lite::QT_DEFAULT;
}
inline lite::CpuBindMode A2L_ConvertAffinityMode(int affinity_mode) {
switch (affinity_mode) {
case 0:
return lite::NO_BIND;
case 1:
return lite::HIGHER_CPU;
case 2:
return lite::MID_CPU;
default:
return lite::NO_BIND;
}
}
inline bool IsAffinityModeValid(int affinity_mode) {
return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
}
Status A2L_ConvertContext(Context *a_context, lite::Context *l_context);
Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context);
Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg);
} // namespace mindspore

View File

@@ -45,19 +45,16 @@ CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProt
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &ms_context) {
context_ = ms_context;
lite::Context lite_context;
auto status = A2L_ConvertContext(ms_context.get(), &lite_context);
if (status != kSuccess) {
return status;
}
auto session = std::shared_ptr<session::LiteSession>(
session::LiteSession::CreateSession(static_cast<const char *>(model_data), data_size, &lite_context));
auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(ContextUtils::Convert(ms_context.get())));
if (session == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
}
auto ret = session->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), data_size);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init session failed";
return kLiteError;
}
session_.swap(session);
MS_LOG(DEBUG) << "Build model success.";
return kSuccess;
@@ -65,18 +62,16 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &ms_context) {
lite::Context lite_context;
auto status = A2L_ConvertContext(ms_context.get(), &lite_context);
if (status != kSuccess) {
return status;
}
auto session = std::shared_ptr<session::LiteSession>(lite::LiteSession::CreateSession(model_path, &lite_context));
auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(ContextUtils::Convert(ms_context.get())));
if (session == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
}
auto ret = session->LoadModelAndCompileByPath(model_path);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init session failed";
return kLiteError;
}
session_.swap(session);
MS_LOG(DEBUG) << "Build model success.";
return kSuccess;
@@ -94,16 +89,15 @@ Status ModelImpl::Build() {
return kLiteNullptr;
}
lite::Context model_context;
auto status = A2L_ConvertContext(context_.get(), &model_context);
if (status != kSuccess) {
auto *inner_context = ContextUtils::Convert(context_.get());
if (inner_context == nullptr) {
MS_LOG(ERROR) << "Failed to convert Context to Lite Context";
return status;
return kLiteNullptr;
}
auto create_callback = CreateTrainSessionCallbackHolder();
if (create_callback != nullptr) {
auto session = create_callback(graph_->graph_data_, cfg_, &model_context);
auto session = create_callback(graph_->graph_data_, cfg_, inner_context);
if (session != nullptr) {
session_ = session;
MS_LOG(DEBUG) << "Build model success.";
@@ -113,10 +107,11 @@ Status ModelImpl::Build() {
auto model = graph_->graph_data_->lite_model();
if (model == nullptr || model->buf == nullptr) {
delete inner_context;
MS_LOG(ERROR) << "Lite model has been freed.";
return kLiteError;
}
auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(inner_context));
if (session == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
@@ -436,4 +431,21 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
auto ret = session_->Resize(inner_input, truncated_shape);
return static_cast<StatusCode>(ret);
}
lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
auto session = new (std::nothrow) lite::LiteSession();
if (session == nullptr) {
MS_LOG(ERROR) << "create session failed";
delete context;
return nullptr;
}
auto ret = session->Init(context);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "init session failed";
delete session;
return nullptr;
}
return session;
}
} // namespace mindspore

View File

@@ -29,6 +29,8 @@
#include "include/api/cell.h"
#include "include/lite_session.h"
#include "src/cxx_api/graph/graph_data.h"
#include "src/inner_context.h"
#include "src/lite_session.h"
template <class T>
void clearVectorOfPointers(std::vector<T> *v) {
@@ -42,9 +44,9 @@ void clearVectorOfPointers(std::vector<T> *v) {
namespace mindspore {
typedef std::shared_ptr<session::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
std::shared_ptr<TrainCfg> cfg,
lite::Context *context);
typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
std::shared_ptr<TrainCfg> cfg,
lite::InnerContext *context);
CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
namespace session {
@@ -65,6 +67,7 @@ class ModelImpl {
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
const MSKernelCallBack &after);
lite::LiteSession *CreateLiteSession(lite::InnerContext *context);
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
@@ -81,6 +84,7 @@ class ModelImpl {
return kSuccess;
}
std::vector<Metrics *> GetMetrics() { return metrics_; }
const session::LiteSession *GetSession() const { return session_.get(); }
protected:
// Utility methods
@@ -94,7 +98,7 @@ class ModelImpl {
friend class Model;
friend class Serialization;
std::shared_ptr<Graph> graph_ = nullptr;
std::shared_ptr<session::LiteSession> session_ = nullptr;
std::shared_ptr<lite::LiteSession> session_ = nullptr;
std::shared_ptr<Context> context_ = nullptr;
std::shared_ptr<TrainCfg> cfg_ = nullptr;
std::vector<Metrics *> metrics_;

View File

@@ -25,6 +25,7 @@
#include "include/api/callback/callback.h"
#include "include/api/metrics/metrics.h"
#include "src/lite_model.h"
#include "src/inner_context.h"
#include "src/runtime/inner_allocator.h"
#include "src/common/string_util.h"
#include "src/cxx_api/model/model_impl.h"
@@ -40,8 +41,8 @@
#include "src/train/train_session.h"
namespace mindspore {
std::shared_ptr<session::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
std::shared_ptr<TrainCfg> cfg, lite::Context *context) {
std::shared_ptr<lite::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
std::shared_ptr<TrainCfg> cfg, lite::InnerContext *context) {
bool is_train_session = graph_data->IsTrainModel();
if (is_train_session) {
auto model = graph_data->lite_model();
@@ -49,8 +50,8 @@ std::shared_ptr<session::LiteSession> CreateTrainSession(std::shared_ptr<Graph::
MS_LOG(ERROR) << "Lite model has been freed.";
return nullptr;
}
std::shared_ptr<session::LiteSession> shared_session;
lite::TrainSession *session = new lite::TrainSession();
std::shared_ptr<lite::LiteSession> shared_session;
auto session = new (std::nothrow) lite::TrainSession();
if (session == nullptr) {
MS_LOG(ERROR) << "create session failed";
return nullptr;

View File

@@ -17,7 +17,6 @@
#include <algorithm>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#ifdef SUPPORT_NPU
#include "include/HiAiModelManagerType.h"
#endif
@@ -28,6 +27,8 @@
namespace mindspore::lite {
namespace {
constexpr int kDefaultParallelNum = 2;
constexpr int kMaxLiteContextDeviceNums = 2;
constexpr int kMaxInnerContextDeviceNums = 3;
} // namespace
InnerContext::InnerContext(const Context *context) {
@@ -46,24 +47,54 @@ InnerContext::InnerContext(const Context *context) {
}
void InnerContext::SetContextDevice(const Context *context) {
this->device_list_.clear();
if (context->device_list_.size() > kMaxLiteContextDeviceNums || context->device_list_.size() <= 0) {
return;
}
if (context->device_list_.front().device_type_ != DT_CPU) {
return;
}
/* user set order for different device */
if (context->device_list_.size() < kMaxLiteContextDeviceNums) {
this->device_list_.push_back(context->device_list_.front());
return;
}
/* keep compatibility:
* if the user sets CPU together with NPU/GPU,
* NPU/GPU gets the higher priority */
bool isUserSetNPU = context->device_list_.end() !=
std::find_if(context->device_list_.begin(), context->device_list_.end(),
std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_NPU; });
bool isUserSetGPU = context->device_list_.end() !=
std::find_if(context->device_list_.begin(), context->device_list_.end(),
std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_GPU; });
this->device_list_.clear();
if (isUserSetGPU == false && isUserSetNPU == false) {
return;
}
/* add GPU/NPU first */
for (auto &device_ctx : context->device_list_) {
// NPU/GPU execution occupies one core, so we don't bind cores, to avoid competition.
// If the user does not set an NPU/GPU device, we still bind cores.
if (device_ctx.device_type_ == DT_CPU && (isUserSetNPU || (isUserSetGPU && !enable_parallel_))) {
auto cpu_ctx = device_ctx;
cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
this->device_list_.push_back(cpu_ctx);
} else {
if (device_ctx.device_type_ != DT_CPU) {
this->device_list_.push_back(device_ctx);
}
}
/* add CPU */
for (auto &device_ctx : context->device_list_) {
if (device_ctx.device_type_ == DT_CPU) {
if (isUserSetNPU || (isUserSetGPU && enable_parallel_ == false)) {
auto cpu_ctx = device_ctx;
cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
this->device_list_.push_back(cpu_ctx);
} else {
this->device_list_.push_back(device_ctx);
}
}
}
return;
}
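The two loops above append every non-CPU device first and CPU last, clearing the CPU core binding whenever NPU is set (or GPU is set without parallelism) so the bound cores do not compete with the accelerator. A simplified standalone sketch of that reordering, with hypothetical stand-in types:

#include <vector>

namespace demo {
enum DeviceType { DT_CPU, DT_GPU, DT_NPU };
enum CpuBindMode { NO_BIND, HIGHER_CPU, MID_CPU };
struct DeviceContext { DeviceType type; CpuBindMode bind_mode; };

// Accelerators first, CPU last; unbind CPU cores when they would compete.
std::vector<DeviceContext> Reorder(const std::vector<DeviceContext> &user_list, bool enable_parallel) {
  bool has_npu = false, has_gpu = false;
  for (const auto &d : user_list) {
    has_npu = has_npu || (d.type == DT_NPU);
    has_gpu = has_gpu || (d.type == DT_GPU);
  }
  std::vector<DeviceContext> out;
  for (const auto &d : user_list) {  // add GPU/NPU first
    if (d.type != DT_CPU) out.push_back(d);
  }
  for (auto d : user_list) {  // then add CPU
    if (d.type == DT_CPU) {
      if (has_npu || (has_gpu && !enable_parallel)) d.bind_mode = NO_BIND;
      out.push_back(d);
    }
  }
  return out;
}
}  // namespace demo

For example, a user list of {CPU(HIGHER_CPU), NPU} comes out as {NPU, CPU(NO_BIND)}.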
int InnerContext::Init() {
@@ -71,12 +102,11 @@ int InnerContext::Init() {
MS_LOG(ERROR) << "Context is not valid";
return RET_NOT_SUPPORT;
}
if (this->thread_pool_ == nullptr && this->IsCpuEnabled()) {
if (this->thread_pool_ == nullptr) {
int actor_parallel_thread = this->enable_parallel_ ? kDefaultParallelNum : 1;
if (this->affinity_core_list_.empty()) {
auto bind_mode = static_cast<BindMode>(this->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_);
thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_, bind_mode);
thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_,
static_cast<BindMode>(GetCpuBindMode()));
if (thread_pool_ == nullptr) {
MS_LOG(ERROR) << "Create ThreadPool failed";
return RET_NULL_PTR;
@@ -131,8 +161,8 @@ int InnerContext::IsValid() const {
MS_LOG(ERROR) << "Device list is empty.";
return RET_NOT_SUPPORT;
}
if (this->device_list_.size() > kMaxDeviceNums) {
MS_LOG(ERROR) << "Not support device list more than 2.";
if (this->device_list_.size() > kMaxInnerContextDeviceNums) {
MS_LOG(ERROR) << "Not support device list more than " << kMaxInnerContextDeviceNums;
return RET_NOT_SUPPORT;
}
if (thread_num_ < 1) {
@@ -144,11 +174,6 @@ int InnerContext::IsValid() const {
return RET_NOT_SUPPORT;
}
if (!IsUserSetCpu()) {
MS_LOG(ERROR) << "CPU context should be set.";
return RET_NOT_SUPPORT;
}
if (IsCpuBindModeInvalid()) {
MS_LOG(ERROR) << "CPU bind mode should be one of NO_BIND, HIGHER_CPU or MID_CPU.";
return RET_NOT_SUPPORT;
@@ -196,6 +221,13 @@ bool InnerContext::IsGpuFloat16Enabled() const {
bool InnerContext::IsCpuEnabled() const { return IsUserSetCpu(); }
int InnerContext::GetCpuBindMode() const {
auto iter = std::find_if(device_list_.begin(), device_list_.end(), [](const DeviceContext &device) {
return device.device_type_ == DeviceType::DT_CPU;
});
return iter != device_list_.end() ? iter->device_info_.cpu_device_info_.cpu_bind_mode_ : NO_BIND;
}
bool InnerContext::IsGpuEnabled() const {
#ifdef SUPPORT_GPU
return IsUserSetGpu();

View File

@@ -26,7 +26,6 @@
#endif
namespace mindspore::lite {
const constexpr int kMaxDeviceNums = 2;
struct InnerContext : public Context {
public:
InnerContext() = default;
@@ -41,6 +40,8 @@ struct InnerContext : public Context {
bool IsCpuEnabled() const;
int GetCpuBindMode() const;
bool IsGpuEnabled() const;
bool IsNpuEnabled() const;
@@ -82,7 +83,6 @@ struct InnerContext : public Context {
};
int ParallelLaunch(const Context *context, const Func &func, Content content, int task_num);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_INNER_CONTEXT_H

View File

@@ -655,7 +655,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af
return ret;
}
int LiteSession::Init(const Context *context) {
int LiteSession::Init(InnerContext *context) {
bool expected = false;
if (!is_running_.compare_exchange_strong(expected, true)) {
MS_LOG(ERROR) << "Not support multi-threading";
@@ -666,20 +666,22 @@ int LiteSession::Init(const Context *context) {
is_running_.store(false);
return RET_NULL_PTR;
}
this->context_ = new (std::nothrow) InnerContext(context);
if (this->context_ == nullptr) {
MS_LOG(ERROR) << "New Context failed";
is_running_.store(false);
return RET_MEMORY_FAILED;
}
this->context_ = context;
auto ret = this->context_->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init Context failed";
is_running_.store(false);
return ret;
}
if (context->delegate != nullptr) {
delegate_ = context->delegate;
ms_context_ = MSContextFromContext(context_);
if (ms_context_ == nullptr) {
MS_LOG(ERROR) << "transfer context to ms context failed.";
is_running_.store(false);
return RET_NULL_PTR;
}
if (context_->delegate != nullptr) {
delegate_ = context_->delegate;
}
#if SUPPORT_NPU
if (delegate_ == nullptr && context_->IsNpuEnabled()) {
@@ -688,7 +690,6 @@ int LiteSession::Init(const Context *context) {
MS_LOG(ERROR) << "New delegate_ failed";
return RET_ERROR;
}
this->context_->delegate = delegate_;
}
#endif
#if GPU_TENSORRT
@@ -698,7 +699,6 @@ int LiteSession::Init(const Context *context) {
MS_LOG(ERROR) << "New tensorrt delegate_ failed";
return RET_ERROR;
}
this->context_->delegate = delegate_;
}
#endif
if (delegate_ != nullptr) {
@@ -720,12 +720,6 @@ int LiteSession::Init(const Context *context) {
is_running_.store(false);
return ret;
}
ms_context_ = MSContextFromContext(context);
if (ms_context_ == nullptr) {
MS_LOG(ERROR) << "transfer context to ms context failed.";
is_running_.store(false);
return RET_NULL_PTR;
}
is_running_.store(false);
return RET_OK;
}
@@ -921,14 +915,15 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
}
int LiteSession::InitGPURuntime() {
CpuBindMode cpu_bind_mode = this->context_->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_;
ActorThreadPool *thread_pool = this->context_->thread_pool();
if (thread_pool == nullptr) {
MS_LOG(ERROR) << "thread pool is nullptr";
is_running_.store(false);
return RET_NULL_PTR;
if (context_->IsCpuEnabled()) {
ActorThreadPool *thread_pool = this->context_->thread_pool();
if (thread_pool == nullptr) {
MS_LOG(ERROR) << "thread pool is nullptr";
is_running_.store(false);
return RET_NULL_PTR;
}
thread_pool->SetProcessAffinity(static_cast<BindMode>(context_->GetCpuBindMode()));
}
thread_pool->SetProcessAffinity(static_cast<BindMode>(cpu_bind_mode));
#if GPU_OPENCL
if (this->context_->IsGpuEnabled()) {
opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
@@ -971,7 +966,10 @@ int LiteSession::InitGPURuntime() {
}
#endif
// Setting the core binding would affect OpenCL driver scheduling.
thread_pool->SetProcessAffinity(static_cast<BindMode>(NO_BIND));
if (context_->IsCpuEnabled()) {
ActorThreadPool *thread_pool = this->context_->thread_pool();
thread_pool->SetProcessAffinity(static_cast<BindMode>(NO_BIND));
}
return RET_OK;
}
} // namespace lite
@@ -982,7 +980,7 @@ session::LiteSession *session::LiteSession::CreateSession(const lite::Context *c
MS_LOG(ERROR) << "create session failed";
return nullptr;
}
auto ret = session->Init(context);
auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context));
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "init session failed";
delete session;
@@ -998,52 +996,67 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
MS_LOG(ERROR) << "Create session failed";
return nullptr;
}
auto *model = lite::ImportFromBuffer(model_buf, size, true);
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
auto ret = reinterpret_cast<lite::LiteSession *>(session)->LoadModelAndCompileByBuf(model_buf, size);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init session failed";
delete session;
return nullptr;
}
auto ret = session->CompileGraph(model);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Compile model failed";
model->buf = nullptr;
delete model;
delete session;
return nullptr;
}
model->buf = nullptr;
(reinterpret_cast<lite::LiteSession *>(session))->set_model(model);
return session;
}
session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_path, const lite::Context *context) {
size_t model_size;
auto model_buf = lite::ReadFile(model_path.c_str(), &model_size);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "Read model file failed";
return nullptr;
}
auto *session = session::LiteSession::CreateSession(context);
if (session == nullptr) {
MS_LOG(ERROR) << "Create session failed";
return nullptr;
}
auto *model = lite::ImportFromBuffer(model_buf, model_size, true);
if (model == nullptr) {
auto ret = reinterpret_cast<lite::LiteSession *>(session)->LoadModelAndCompileByPath(model_path);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init session failed";
delete session;
MS_LOG(ERROR) << "Import model failed";
return nullptr;
}
(reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
auto ret = session->CompileGraph(model);
if (ret != lite::RET_OK) {
delete model;
delete session;
MS_LOG(ERROR) << "Compile model failed";
return nullptr;
}
(reinterpret_cast<lite::LiteSession *>(session))->set_model(model);
return session;
}
int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, size_t buf_size) {
auto *model = lite::ImportFromBuffer(model_buf, buf_size, true);
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
return RET_ERROR;
}
auto ret = CompileGraph(model);
model->buf = nullptr;
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Compile model failed";
delete model;
return RET_ERROR;
}
set_model(model);
return RET_OK;
}
int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path) {
size_t model_size;
auto model_buf = lite::ReadFile(model_path.c_str(), &model_size);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "Read model file failed";
return RET_ERROR;
}
auto *model = lite::ImportFromBuffer(model_buf, model_size, true);
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
return RET_ERROR;
}
(reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
auto ret = CompileGraph(model);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Compile model failed";
delete model;
return RET_ERROR;
}
set_model(model);
return RET_OK;
}
} // namespace mindspore
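LoadModelAndCompileByBuf/LoadModelAndCompileByPath factor the model import and CompileGraph steps out of the factory functions, so every construction path now follows the same two steps: Init with an InnerContext, then load-and-compile. A minimal sketch of that flow, assuming the internal lite headers (the session appears to take over the InnerContext, mirroring CreateSession above):

#include <string>
#include "include/errorcode.h"
#include "src/inner_context.h"
#include "src/lite_session.h"

mindspore::lite::LiteSession *BuildSession(const mindspore::lite::Context *context, const std::string &model_path) {
  auto *session = new (std::nothrow) mindspore::lite::LiteSession();
  if (session == nullptr) {
    return nullptr;
  }
  // Step 1: initialize with an InnerContext built from the user context.
  auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context));
  if (ret != mindspore::lite::RET_OK) {
    delete session;
    return nullptr;
  }
  // Step 2: import the model and compile the graph in one call.
  if (session->LoadModelAndCompileByPath(model_path) != mindspore::lite::RET_OK) {
    delete session;
    return nullptr;
  }
  return session;
}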

View File

@@ -49,7 +49,10 @@ class LiteSession : public session::LiteSession {
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
virtual int Init(const Context *context);
int LoadModelAndCompileByBuf(const char *model_buf, size_t buf_size);
int LoadModelAndCompileByPath(const std::string &model_path);
virtual int Init(InnerContext *context);
void BindThread(bool if_bind) override;

View File

@@ -33,6 +33,7 @@
#include "src/common/version_manager.h"
#include "src/common/prim_util.h"
#include "src/common/tensor_util.h"
#include "src/common/context_util.h"
#include "src/runtime/infer_manager.h"
#include "src/sub_graph_split.h"
#include "src/weight_decoder.h"
@@ -234,6 +235,19 @@ int Scheduler::InitKernels(std::vector<kernel::LiteKernel *> dst_kernels) {
return RET_OK;
}
int Scheduler::CheckCpuValid(const std::vector<kernel::LiteKernel *> *dst_kernels) const {
if (context_->IsCpuEnabled()) {
return RET_OK;
}
for (auto kernel : *dst_kernels) {
if (kernel->desc().arch == kernel::KERNEL_ARCH::kCPU) {
MS_LOG(ERROR) << "kernel: " << kernel->name() << " only support in CPU.";
return RET_ERROR;
}
}
return RET_OK;
}
int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
if (dst_kernels == nullptr) {
return RET_ERROR;
@@ -270,16 +284,19 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
MS_LOG(ERROR) << "Schedule graph to kernels failed.";
return ret;
}
SetSubgraphForPartialNode();
if (delegate_ != nullptr) {
ret = ReplaceDelegateKernels(dst_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Repalce delegate kernels failed.";
return ret;
}
context_->thread_pool()->SetMaxSpinCount(kMinSpinCount);
ret = InitDelegateKernels(dst_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Repalce delegate kernels failed.";
return ret;
}
ret = CheckCpuValid(dst_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "kernels invalid in set devices.";
return ret;
}
FindAllInoutKernels(*dst_kernels);
if (IsControlFlowParttern(*dst_kernels)) {
@@ -378,6 +395,55 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
return RET_OK;
}
int Scheduler::InitDelegateKernels(std::vector<kernel::LiteKernel *> *dst_kernels) {
/* no valid delegate */
if (delegate_ == nullptr) {
return RET_OK;
}
/* set delegate spin count */
context_->thread_pool()->SetMaxSpinCount(kMinSpinCount);
/* Inner delegate : check Priority */
std::vector<kernel::LiteKernel *> src_kernels = *dst_kernels;
dst_kernels->clear();
while (!src_kernels.empty()) {
std::vector<kernel::LiteKernel *> tmp_kernels;
kernel::LiteKernel *remain_kernel = nullptr;
/* loop to collect subgraphs for the inner NPU and TensorRT delegates */
while (!src_kernels.empty()) {
auto kernel = src_kernels.front();
VectorErase(&src_kernels, kernel);
bool priority_ret = DeviceTypePriority(context_, DT_NPU, KernelArchToDeviceType(kernel->desc().arch));
if (priority_ret == true) {
tmp_kernels.push_back(kernel);
} else {
remain_kernel = kernel;
break;
}
}
/* replace the collected batch of NPU kernels */
if (tmp_kernels.empty()) {
if (remain_kernel != nullptr) {
dst_kernels->push_back(remain_kernel);
remain_kernel = nullptr;
}
continue;
}
auto ret = ReplaceDelegateKernels(&tmp_kernels);
if (ret != RET_OK) {
MS_LOG(ERROR) << "NPU delegate repalce delegate kernels failed.";
return ret;
}
dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
tmp_kernels.clear();
if (remain_kernel != nullptr) {
dst_kernels->push_back(remain_kernel);
remain_kernel = nullptr;
}
}
return RET_OK;
}
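InitDelegateKernels walks the scheduled kernels in order, batching consecutive kernels whose backend is outranked by NPU and handing each batch to the delegate while higher-priority kernels stay in place. A simplified sketch of that run-partitioning pattern with hypothetical types (the real code fuses each run via ReplaceDelegateKernels):

#include <vector>

struct DemoKernel { bool delegate_eligible; };  // stand-in for the priority check

// Split kernels into maximal runs of delegate-eligible kernels; ineligible
// kernels terminate the current run and pass through unchanged.
std::vector<std::vector<DemoKernel>> PartitionRuns(const std::vector<DemoKernel> &kernels) {
  std::vector<std::vector<DemoKernel>> runs;
  std::vector<DemoKernel> current;
  for (const auto &k : kernels) {
    if (k.delegate_eligible) {
      current.push_back(k);
    } else if (!current.empty()) {
      runs.push_back(current);  // close the run at the first ineligible kernel
      current.clear();
    }
  }
  if (!current.empty()) runs.push_back(current);
  return runs;
}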
void Scheduler::FindNodeInoutTensors(const lite::Model::Node &node, std::vector<Tensor *> *inputs,
std::vector<Tensor *> *outputs) {
MS_ASSERT(inputs != nullptr);
@@ -914,7 +980,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
op_parameter->is_train_session_ = is_train_session_;
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
#ifdef GPU_OPENCL
if (node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType) {
bool gpu_priority = DeviceTypePriority(context_, DT_GPU, DT_CPU);
bool use_gpu_kernel = node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType;
if (gpu_priority && use_gpu_kernel) {
status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
if (status == RET_OK) {
return kernel;

View File

@@ -69,12 +69,15 @@ class Scheduler {
int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
kernel::LiteKernel **kernel);
int CheckCpuValid(const std::vector<kernel::LiteKernel *> *dst_kernels) const;
int FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
int FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel);
int ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_kernels);
int InitDelegateKernels(std::vector<kernel::LiteKernel *> *dst_kernels);
int InitKernels(std::vector<kernel::LiteKernel *> dst_kernels);
kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node);
// schedule a partial node to a subgraph_kernel

View File

@@ -51,7 +51,7 @@ TrainSession::TrainSession() {
InitCallBack();
}
int TrainSession::Init(const Context *context, const TrainCfg *train_cfg) {
int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) {
if (train_cfg != nullptr) {
if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) {
MS_LOG(ERROR) << "illegal loss scale configuration";
@@ -769,7 +769,7 @@ session::LiteSession *session::TrainSession::CreateTrainSession(const std::strin
return nullptr;
}
auto ret = session->Init(context, cfg);
auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context), cfg);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "init session failed";
return nullptr;

View File

@@ -54,7 +54,7 @@ class TrainSession : virtual public lite::LiteSession {
int CompileGraph(lite::Model *model) override;
virtual int CompileTrainGraph(std::shared_ptr<Model> model);
virtual int Init(const Context *context, const TrainCfg *train_cfg);
virtual int Init(InnerContext *context, const TrainCfg *train_cfg);
int Train() override;
int Eval() override;

View File

@@ -248,7 +248,7 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back
return nullptr;
}
auto ret = session->Init(context, cfg);
auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context), cfg);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "init transfer session failed";
delete session;

View File

@@ -97,21 +97,20 @@ int BenchmarkCApi::InitContext() {
MSContextSetEnableParallel(context_, flags_->enable_parallel_);
MSContextSetThreadAffinityMode(context_, flags_->cpu_bind_mode_);
MSDeviceInfoHandle cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU);
MSDeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_);
MSContextAddDeviceInfo(context_, cpu_device_info);
if (flags_->device_ == "GPU") {
MSDeviceInfoHandle gpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeGPU);
MSDeviceInfoSetEnableFP16(gpu_device_info, flags_->enable_fp16_);
MSContextAddDeviceInfo(context_, gpu_device_info);
}
if (flags_->device_ == "NPU") {
MSDeviceInfoHandle npu_device_info = MSDeviceInfoCreate(kMSDeviceTypeKirinNPU);
MSDeviceInfoSetFrequency(npu_device_info, kFrequencyDefault);
MSContextAddDeviceInfo(context_, npu_device_info);
}
MSDeviceInfoHandle cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU);
MSDeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_);
MSContextAddDeviceInfo(context_, cpu_device_info);
return RET_OK;
}

View File

@@ -127,10 +127,6 @@ void BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context
context->SetThreadAffinity(flags_->cpu_bind_mode_);
auto &device_list = context->MutableDeviceInfo();
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(flags_->enable_fp16_);
device_list.push_back(device_info);
if (flags_->device_ == "GPU") {
std::shared_ptr<GPUDeviceInfo> gpu_device_info = std::make_shared<GPUDeviceInfo>();
gpu_device_info->SetEnableFP16(flags_->enable_fp16_);
@@ -142,6 +138,11 @@ void BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context
npu_device_info->SetFrequency(kFrequencyDefault);
device_list.push_back(npu_device_info);
}
// CPU is added last so its priority stays behind GPU and NPU
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(flags_->enable_fp16_);
device_list.push_back(device_info);
}
int BenchmarkUnifiedApi::CompareOutput() {