forked from mindspore-Ecosystem/mindspore
fix device priority
This commit is contained in:
parent 19dd2e21b4 · commit 62230a0e78
@@ -32,8 +32,8 @@ class ModelC {
   }
   }

-  Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context);
-  Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context);
+  int Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context);
+  int Build(const std::string &model_path, ModelType model_type, const ContextC *model_context);
   Status Resize(const std::vector<MSTensor::Impl *> &inputs, const std::vector<std::vector<int64_t>> &shapes);

   Status Predict(const MSTensorHandle *inputs, size_t input_num, MSTensorHandle **outputs, size_t *output_num,
@@ -45,45 +45,41 @@ class ModelC {
   MSTensor::Impl *GetOutputByTensorName(const std::string &name);

  private:
-  std::shared_ptr<session::LiteSession> session_ = nullptr;
-  std::shared_ptr<const ContextC> context_ = nullptr;
+  std::unique_ptr<lite::LiteSession> session_ = nullptr;
+  std::unique_ptr<const ContextC> context_ = nullptr;
   std::map<mindspore::tensor::MSTensor *, MSTensor::Impl *> tensor_map_;
   std::vector<MSTensor::Impl *> inputs_;
   std::vector<MSTensor::Impl *> outputs_;
+  int CreateLiteSession(const ContextC *context);
   Status RunGraph(const MSKernelCallBackC &before, const MSKernelCallBackC &after);
   void ResetTensorData(std::vector<void *> old_data, std::vector<tensor::MSTensor *> tensors);
   MSTensor::Impl *TensorToTensorImpl(mindspore::tensor::MSTensor *tensor);
 };

-Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) {
-  context_.reset(model_context);
-  lite::Context lite_context;
-  auto status = A2L_ConvertContext(model_context, &lite_context);
-  if (status != kSuccess) {
-    return status;
-  }
-  session_ = std::shared_ptr<session::LiteSession>(
-    session::LiteSession::CreateSession(static_cast<const char *>(model_data), data_size, &lite_context));
-  if (session_ == nullptr) {
-    MS_LOG(ERROR) << "Allocate session failed.";
-    return kLiteNullptr;
-  }
-  return kSuccess;
-}
+int ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) {
+  int ret = CreateLiteSession(model_context);
+  if (ret != lite::RET_OK) {
+    return ret;
+  }
+  return session_->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), data_size);
+}

-Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) {
-  context_.reset(model_context);
-  lite::Context lite_context;
-  auto status = A2L_ConvertContext(model_context, &lite_context);
-  if (status != kSuccess) {
-    return status;
-  }
-  session_ = std::shared_ptr<session::LiteSession>(lite::LiteSession::CreateSession(model_path, &lite_context));
-  if (session_ == nullptr) {
-    MS_LOG(ERROR) << "Allocate session failed.";
-    return kLiteError;
-  }
-  return kSuccess;
-}
+int ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) {
+  int ret = CreateLiteSession(model_context);
+  if (ret != lite::RET_OK) {
+    return ret;
+  }
+  return session_->LoadModelAndCompileByPath(model_path);
+}

+int ModelC::CreateLiteSession(const ContextC *context) {
+  context_.reset(context);
+  session_.reset(new (std::nothrow) lite::LiteSession());
+  if (session_ == nullptr) {
+    MS_LOG(ERROR) << "create session failed";
+    return kLiteMemoryFailed;
+  }
+  return session_->Init(ContextUtils::Convert(context_.get()));
+}

 Status ModelC::Resize(const std::vector<MSTensor::Impl *> &inputs, const std::vector<std::vector<int64_t>> &shapes) {

@@ -332,7 +328,7 @@ MSStatus MSModelBuild(MSModelHandle model, const void *model_data, size_t data_s
   mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
   auto impl = static_cast<mindspore::ModelC *>(model);
   auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context);
-  return static_cast<MSStatus>(ret.StatusCode());
+  return static_cast<MSStatus>(ret);
 }

 MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSModelType model_type,
@@ -348,7 +344,7 @@ MSStatus MSModelBuildFromFile(MSModelHandle model, const char *model_path, MSMod
   mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
   auto impl = static_cast<mindspore::ModelC *>(model);
   auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context);
-  return static_cast<MSStatus>(ret.StatusCode());
+  return static_cast<MSStatus>(ret);
 }

 MSStatus MSModelResize(MSModelHandle model, const MSTensorHandleArray inputs, MSShapeInfo *shape_infos,
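Note: since `ModelC::Build` now returns the raw lite error code (`int`) instead of `Status`, the C entry points cast it straight through. A hedged caller-side sketch — `MSModelBuild` is the real entry point shown above, but the handle setup and the `kMSModelTypeMindIR`/`kMSStatusSuccess` constants are assumptions about the surrounding C API, not part of this diff:

```cpp
// Sketch only: `model`, `context`, `model_data` and `data_size` are assumed
// to have been created elsewhere via the C API.
MSStatus ret = MSModelBuild(model, model_data, data_size, kMSModelTypeMindIR, context);
if (ret != kMSStatusSuccess) {
  // The value is now a static_cast of the lite error code returned by ModelC::Build.
}
```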
@@ -115,5 +115,36 @@ std::set<std::string> ProvidersFromMSContext(const mindspore::Context *context)
   }
   return providers;
 }
+
+bool DeviceTypePriority(const lite::Context *context, int device_type1, int device_type2) {
+  /* dt1 > dt2 true
+   * dt1 < dt2 false */
+  if (context == nullptr) {
+    return false;
+  }
+  DeviceContextVector device_infos = context->device_list_;
+  for (DeviceContext device_info : device_infos) {
+    if (device_info.device_type_ == device_type1) {
+      return true;
+    }
+    if (device_info.device_type_ == device_type2) {
+      return false;
+    }
+  }
+  return false;
+}
+
+DeviceType KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch) {
+  switch (kernel_arch) {
+    case kernel::KERNEL_ARCH::kCPU:
+      return DT_CPU;
+    case kernel::KERNEL_ARCH::kGPU:
+      return DT_GPU;
+    case kernel::KERNEL_ARCH::kNPU:
+      return DT_NPU;
+    default:
+      return DT_CPU;
+  }
+}
 }  // namespace lite
 }  // namespace mindspore
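Note: priority here is purely positional — whichever of the two device types appears first in `device_list_` wins, and the function falls back to `false` when the context is null or neither type is present. A rough usage sketch (the `DeviceInfo` payloads are left empty as placeholders):

```cpp
#include "include/context.h"             // lite::Context, DeviceContext (path as in the diff)
#include "src/common/context_util.h"     // DeviceTypePriority (path as in the diff)

// Sketch only: DeviceContext fields beyond device_type_ are value-initialized.
mindspore::lite::Context ctx;
ctx.device_list_.push_back({mindspore::lite::DT_NPU, {}});  // earlier entry = higher priority
ctx.device_list_.push_back({mindspore::lite::DT_CPU, {}});
bool npu_over_cpu = mindspore::lite::DeviceTypePriority(&ctx, mindspore::lite::DT_NPU,
                                                        mindspore::lite::DT_CPU);  // true
bool cpu_over_npu = mindspore::lite::DeviceTypePriority(&ctx, mindspore::lite::DT_CPU,
                                                        mindspore::lite::DT_NPU);  // false
```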
@@ -21,11 +21,14 @@
 #include <string>
 #include "include/context.h"
 #include "include/api/context.h"
+#include "src/lite_kernel.h"

 namespace mindspore {
 namespace lite {
 mindspore::Context *MSContextFromContext(const lite::Context *context);
 std::set<std::string> ProvidersFromMSContext(const mindspore::Context *context);
+bool DeviceTypePriority(const lite::Context *context, int device_type1, int device_type2);
+DeviceType KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch);
 }  // namespace lite
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_SRC_COMMON_CONTEXT_UTIL_H_
@@ -14,149 +14,127 @@
  * limitations under the License.
  */
 #include "src/cxx_api/converters.h"
 #include <cstddef>
 #include <string>
 #include <vector>
 #include <memory>
 #include "include/context.h"
 #include "include/api/context.h"
 #include "src/runtime/inner_allocator.h"
 #include "src/common/log_adapter.h"

 namespace mindspore {
-constexpr static int kMaxNumOfDevices = 2;
+constexpr static int kMaxNumOfDevices = 3;

-Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) {
-  if ((a_context == nullptr) || (l_context == nullptr)) {
-    MS_LOG(ERROR) << "Invalid context pointers.";
-    return kLiteNullptr;
-  }
-
-  auto device_list = a_context->MutableDeviceInfo();
-  if (device_list.size() == 0) {
-    MS_LOG(ERROR) << "Invalid device list.";
-    return kLiteInputParamInvalid;
-  }
-  if (device_list.size() > kMaxNumOfDevices) {
-    MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
-    return kLiteInputParamInvalid;
-  }
-  l_context->thread_num_ = a_context->GetThreadNum();
-  l_context->enable_parallel_ = a_context->GetEnableParallel();
-  l_context->affinity_core_list_ = a_context->GetThreadAffinityCoreList();
-  l_context->device_list_.clear();
-  if (device_list[0]->GetDeviceType() != kCPU) {
-    MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
-    return kLiteInputParamInvalid;
-  }
-
-  auto cpu_context = device_list[0]->Cast<CPUDeviceInfo>();
-  l_context->allocator = cpu_context->GetAllocator();
-  if (l_context->allocator == nullptr) {
-    l_context->allocator = Allocator::Create();
-    if (l_context->allocator == nullptr) {
-      MS_LOG(ERROR) << "Create Allocator failed.";
-      return kLiteNullptr;
-    }
-    MS_LOG(DEBUG) << "Set new allocator.";
-    cpu_context->SetAllocator(l_context->allocator);
-  }
-
-  if (!IsAffinityModeValid(a_context->GetThreadAffinityMode())) {
-    MS_LOG(ERROR)
-      << "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first.";
-    return kLiteInputParamInvalid;
-  }
-  lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->GetThreadAffinityMode());
-
-  lite::DeviceInfo cpu_info = {0};
-  cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
-  l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
-                                     cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
-  if (device_list.size() == kMaxNumOfDevices) {
-    lite::DeviceInfo device_info = {0};
-    if (device_list[1]->GetDeviceType() == kGPU) {
-      auto gpu_context = device_list[1]->Cast<GPUDeviceInfo>();
-      device_info.gpu_device_info_ = {gpu_context->GetEnableFP16()};
-      l_context->device_list_.push_back({lite::DT_GPU, device_info, gpu_context->GetProvider(),
-                                         gpu_context->GetProviderDevice(), gpu_context->GetAllocator()});
-    } else if (device_list[1]->GetDeviceType() == kKirinNPU) {
-      auto npu_context = device_list[1]->Cast<KirinNPUDeviceInfo>();
-      device_info.npu_device_info_ = {npu_context->GetFrequency()};
-      l_context->device_list_.push_back({lite::DT_NPU, device_info});
-    } else {
-      MS_LOG(ERROR) << "Invalid device.";
-      return kLiteInputParamInvalid;
-    }
-  }
-  l_context->delegate = a_context->GetDelegate();
-  return kSuccess;
-}
-
-Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context) {
-  if ((a_context == nullptr) || (l_context == nullptr)) {
-    MS_LOG(ERROR) << "Invalid context pointers.";
-    return kLiteNullptr;
-  }
-
-  auto device_list = a_context->device_info_list;
-  if (device_list.size() == 0) {
-    MS_LOG(ERROR) << "Invalid device list.";
-    return kLiteInputParamInvalid;
-  }
-  if (device_list.size() > kMaxNumOfDevices) {
-    MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
-    return kLiteInputParamInvalid;
-  }
-  l_context->thread_num_ = a_context->thread_num;
-  l_context->enable_parallel_ = a_context->enable_parallel;
-  l_context->affinity_core_list_ = a_context->affinity_core_list;
-  l_context->device_list_.clear();
-  if (device_list[0]->device_type != kMSDeviceTypeCPU) {
-    MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
-    return kLiteInputParamInvalid;
-  }
-
-  auto cpu_context = device_list[0];
-  l_context->allocator = cpu_context->allocator;
-  if (l_context->allocator == nullptr) {
-    l_context->allocator = Allocator::Create();
-    if (l_context->allocator == nullptr) {
-      MS_LOG(ERROR) << "Create Allocator failed.";
-      return kLiteNullptr;
-    }
-    MS_LOG(DEBUG) << "Set new allocator.";
-    cpu_context->allocator = l_context->allocator;
-  }
-
-  if (!IsAffinityModeValid(a_context->affinity_mode)) {
-    MS_LOG(ERROR)
-      << "Invalid affinity mode, only supports 0: no affinities, 1: big cores first, 2: little cores first.";
-    return kLiteInputParamInvalid;
-  }
-  lite::CpuBindMode mode = A2L_ConvertAffinityMode(a_context->affinity_mode);
-
-  lite::DeviceInfo cpu_info = {0};
-  cpu_info.cpu_device_info_ = {cpu_context->enable_fp16, mode};
-  l_context->device_list_.push_back(
-    {lite::DT_CPU, cpu_info, cpu_context->provider, cpu_context->provider_device, cpu_context->allocator});
-  if (device_list.size() == kMaxNumOfDevices) {
-    lite::DeviceInfo device_info = {0};
-    if (device_list[1]->device_type == kMSDeviceTypeGPU) {
-      auto gpu_context = device_list[1];
-      device_info.gpu_device_info_ = {gpu_context->enable_fp16};
-      l_context->device_list_.push_back(
-        {lite::DT_GPU, device_info, gpu_context->provider, gpu_context->provider_device, gpu_context->allocator});
-    } else if (device_list[1]->device_type == kMSDeviceTypeKirinNPU) {
-      auto npu_context = device_list[1];
-      device_info.npu_device_info_ = {npu_context->frequency};
-      l_context->device_list_.push_back({lite::DT_NPU, device_info});
-    } else {
-      MS_LOG(ERROR) << "Invalid device.";
-      return kLiteInputParamInvalid;
-    }
-  }
-  l_context->delegate = a_context->delegate;
-  return kSuccess;
-}
+void ContextUtils::SetContextAttr(int32_t thread_num, bool enable_parallel,
+                                  const std::vector<int32_t> &affinity_core_list,
+                                  const std::shared_ptr<Delegate> &delegate, lite::InnerContext *inner_context) {
+  inner_context->thread_num_ = thread_num;
+  inner_context->enable_parallel_ = enable_parallel;
+  inner_context->affinity_core_list_ = affinity_core_list;
+  inner_context->delegate = delegate;
+}
+
+Status ContextUtils::AddCpuDevice(const std::shared_ptr<Allocator> &allocator, int affinity_mode, bool enable_fp16,
+                                  const std::string &provider, const std::string &provider_device,
+                                  lite::InnerContext *inner_context) {
+  inner_context->allocator = allocator;
+  if (!IsAffinityModeValid(affinity_mode)) {
+    MS_LOG(ERROR) << "Invalid affinity mode, only supports 0:no affinities, 1:big cores first, 2:little cores first.";
+    return kLiteInputParamInvalid;
+  }
+  lite::DeviceInfo device_info = {0};
+  device_info.cpu_device_info_ = {enable_fp16, static_cast<lite::CpuBindMode>(affinity_mode)};
+  inner_context->device_list_.push_back({lite::DT_CPU, device_info, provider, provider_device, allocator});
+  return kSuccess;
+}
+
+Status ContextUtils::AddGpuDevice(bool enable_fp16, const std::string &provider, const std::string &provider_device,
+                                  const std::shared_ptr<Allocator> &allocator, lite::InnerContext *inner_context) {
+  lite::DeviceInfo device_info = {0};
+  device_info.gpu_device_info_ = {enable_fp16};
+  inner_context->device_list_.push_back({lite::DT_GPU, device_info, provider, provider_device, allocator});
+  return kSuccess;
+}
+
+Status ContextUtils::AddNpuDevice(int frequency, lite::InnerContext *inner_context) {
+  lite::DeviceInfo device_info = {0};
+  device_info.npu_device_info_ = {frequency};
+  inner_context->device_list_.push_back({lite::DT_NPU, device_info});
+  return kSuccess;
+}
+
+lite::InnerContext *ContextUtils::Convert(Context *context) {
+  auto inner_context = std::make_unique<lite::InnerContext>();
+  if ((context == nullptr) || (inner_context == nullptr)) {
+    MS_LOG(ERROR) << "Invalid context pointers.";
+    return nullptr;
+  }
+  auto device_list = context->MutableDeviceInfo();
+  if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) {
+    MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices;
+    return nullptr;
+  }
+  SetContextAttr(context->GetThreadNum(), context->GetEnableParallel(), context->GetThreadAffinityCoreList(),
+                 context->GetDelegate(), inner_context.get());
+  inner_context->device_list_.clear();
+  Status ret = kLiteError;
+  for (auto &device : device_list) {
+    if (device == nullptr) {
+      return nullptr;
+    }
+    if (device->GetDeviceType() == kCPU) {
+      auto cpu_context = device->Cast<CPUDeviceInfo>();
+      if (cpu_context->GetAllocator() == nullptr) {
+        cpu_context->SetAllocator(Allocator::Create());
+      }
+      ret = AddCpuDevice(cpu_context->GetAllocator(), context->GetThreadAffinityMode(), cpu_context->GetEnableFP16(),
+                         cpu_context->GetProvider(), cpu_context->GetProviderDevice(), inner_context.get());
+    } else if (device->GetDeviceType() == kGPU) {
+      auto gpu_context = device->Cast<GPUDeviceInfo>();
+      ret = AddGpuDevice(gpu_context->GetEnableFP16(), gpu_context->GetProvider(), gpu_context->GetProviderDevice(),
+                         gpu_context->GetAllocator(), inner_context.get());
+    } else if (device->GetDeviceType() == kKirinNPU) {
+      auto npu_context = device->Cast<KirinNPUDeviceInfo>();
+      ret = AddNpuDevice(npu_context->GetFrequency(), inner_context.get());
+    }
+    if (ret != kSuccess) {
+      MS_LOG(ERROR) << "Add device failed!";
+      return nullptr;
+    }
+  }
+  return inner_context.release();
+}
+
+lite::InnerContext *ContextUtils::Convert(const ContextC *context_c) {
+  auto inner_context = std::make_unique<lite::InnerContext>();
+  if ((context_c == nullptr) || (inner_context == nullptr)) {
+    MS_LOG(ERROR) << "Invalid context pointers.";
+    return nullptr;
+  }
+  auto device_list = context_c->device_info_list;
+  if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) {
+    MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices;
+    return nullptr;
+  }
+  SetContextAttr(context_c->thread_num, context_c->enable_parallel, context_c->affinity_core_list, context_c->delegate,
+                 inner_context.get());
+  inner_context->device_list_.clear();
+  Status ret = kLiteError;
+  for (auto &device_info_c : device_list) {
+    if (device_info_c == nullptr) {
+      return nullptr;
+    }
+    if (device_info_c->device_type == kMSDeviceTypeCPU) {
+      if (device_info_c->allocator == nullptr) {
+        device_info_c->allocator = Allocator::Create();
+      }
+      ret = AddCpuDevice(device_info_c->allocator, context_c->affinity_mode, device_info_c->enable_fp16,
+                         device_info_c->provider, device_info_c->provider_device, inner_context.get());
+    } else if (device_info_c->device_type == kMSDeviceTypeGPU) {
+      ret = AddGpuDevice(device_info_c->enable_fp16, device_info_c->provider, device_info_c->provider_device,
+                         device_info_c->allocator, inner_context.get());
+    } else if (device_info_c->device_type == kMSDeviceTypeKirinNPU) {
+      ret = AddNpuDevice(device_info_c->frequency, inner_context.get());
+    }
+    if (ret != kSuccess) {
+      MS_LOG(ERROR) << "Add device failed!";
+      return nullptr;
+    }
+  }
+  return inner_context.release();
+}
 }  // namespace mindspore
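Note: both `Convert` overloads hand back a raw, heap-allocated `lite::InnerContext` released from a `unique_ptr`, so ownership sits with the caller until the pointer is passed to `LiteSession::Init`. A hedged sketch of the contract (local names are illustrative, not from the source):

```cpp
// Assumes a configured mindspore::Context held in `ms_context`. Convert
// returns nullptr on any validation failure, in which case nothing leaks,
// because the unique_ptr is released only on the success path.
auto *inner = mindspore::ContextUtils::Convert(ms_context.get());
if (inner == nullptr) {
  return mindspore::kLiteNullptr;  // nothing to free
}
// LiteSession::Init(inner) stores the pointer; after a successful Init the
// session is presumed responsible for it, otherwise delete it yourself.
```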
@@ -13,26 +13,38 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

 #ifndef MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_
 #define MINDSPORE_LITE_SRC_CXX_API_CONVERTERS_H_

 #include <limits.h>
 #include <vector>
 #include <string>
 #include <memory>
 #include "include/api/context.h"
 #include "include/api/status.h"
 #include "include/api/types.h"
 #include "include/lite_types.h"
 #include "src/cxx_api/context.h"
+#include "include/api/cfg.h"
+#include "include/train/train_cfg.h"
+#include "src/inner_context.h"
+#include "src/c_api/context_c.h"

 namespace mindspore {
-
-namespace lite {
-struct Context;
-class TrainCfg;
-}  // namespace lite
-
-class Context;
-class TrainCfg;
+class ContextUtils {
+ public:
+  static lite::InnerContext *Convert(Context *context);
+  static lite::InnerContext *Convert(const ContextC *context_c);
+
+ private:
+  static void SetContextAttr(int32_t thread_num, bool enable_parallel, const std::vector<int32_t> &affinity_core_list,
+                             const std::shared_ptr<Delegate> &delegate, lite::InnerContext *inner_context);
+  static Status AddCpuDevice(const std::shared_ptr<Allocator> &allocator, int affinity_mode, bool enable_fp16,
+                             const std::string &provider, const std::string &provider_device,
+                             lite::InnerContext *inner_context);
+  static Status AddGpuDevice(bool enable_fp16, const std::string &provider, const std::string &provider_device,
+                             const std::shared_ptr<Allocator> &allocator, lite::InnerContext *inner_context);
+  static Status AddNpuDevice(int frequency, lite::InnerContext *inner_context);
+  static bool IsAffinityModeValid(int affinity_mode) {
+    return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
+  }
+};

 inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
   if (qt == kNoQuant) {
@@ -44,25 +56,6 @@ inline lite::QuantizationType A2L_ConvertQT(mindspore::QuantizationType qt) {
   return lite::QT_DEFAULT;
 }

-inline lite::CpuBindMode A2L_ConvertAffinityMode(int affinity_mode) {
-  switch (affinity_mode) {
-    case 0:
-      return lite::NO_BIND;
-    case 1:
-      return lite::HIGHER_CPU;
-    case 2:
-      return lite::MID_CPU;
-    default:
-      return lite::NO_BIND;
-  }
-}
-
-inline bool IsAffinityModeValid(int affinity_mode) {
-  return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
-}
-
-Status A2L_ConvertContext(Context *a_context, lite::Context *l_context);
-Status A2L_ConvertContext(const ContextC *a_context, lite::Context *l_context);
 Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg);
 }  // namespace mindspore
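Note: the standalone `A2L_ConvertAffinityMode` helper disappears because `ContextUtils::AddCpuDevice` now casts the validated mode directly. That is only sound if the enum values line up with the 0/1/2 encoding the old switch implemented; a compile-time check makes the assumption explicit:

```cpp
// The removed switch mapped 0 -> NO_BIND, 1 -> HIGHER_CPU, 2 -> MID_CPU, so
// static_cast<lite::CpuBindMode>(affinity_mode) relies on exactly this layout.
static_assert(mindspore::lite::NO_BIND == 0 && mindspore::lite::HIGHER_CPU == 1 && mindspore::lite::MID_CPU == 2,
              "static_cast<CpuBindMode>(affinity_mode) assumes this enum layout");
```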
@@ -45,19 +45,16 @@ CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProt
 Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
                         const std::shared_ptr<Context> &ms_context) {
   context_ = ms_context;
-  lite::Context lite_context;
-  auto status = A2L_ConvertContext(ms_context.get(), &lite_context);
-  if (status != kSuccess) {
-    return status;
-  }
-
-  auto session = std::shared_ptr<session::LiteSession>(
-    session::LiteSession::CreateSession(static_cast<const char *>(model_data), data_size, &lite_context));
+  auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(ContextUtils::Convert(ms_context.get())));
   if (session == nullptr) {
     MS_LOG(ERROR) << "Allocate session failed.";
     return kLiteNullptr;
   }

+  auto ret = session->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), data_size);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init session failed";
+    return kLiteError;
+  }
   session_.swap(session);
   MS_LOG(DEBUG) << "Build model success.";
   return kSuccess;
@@ -65,18 +62,16 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode

 Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
                         const std::shared_ptr<Context> &ms_context) {
-  lite::Context lite_context;
-  auto status = A2L_ConvertContext(ms_context.get(), &lite_context);
-  if (status != kSuccess) {
-    return status;
-  }
-
-  auto session = std::shared_ptr<session::LiteSession>(lite::LiteSession::CreateSession(model_path, &lite_context));
+  auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(ContextUtils::Convert(ms_context.get())));
   if (session == nullptr) {
     MS_LOG(ERROR) << "Allocate session failed.";
     return kLiteNullptr;
   }
+  auto ret = session->LoadModelAndCompileByPath(model_path);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init session failed";
+    return kLiteError;
+  }

   session_.swap(session);
   MS_LOG(DEBUG) << "Build model success.";
   return kSuccess;
@@ -94,16 +89,15 @@ Status ModelImpl::Build() {
     return kLiteNullptr;
   }

-  lite::Context model_context;
-  auto status = A2L_ConvertContext(context_.get(), &model_context);
-  if (status != kSuccess) {
+  auto *inner_context = ContextUtils::Convert(context_.get());
+  if (inner_context == nullptr) {
     MS_LOG(ERROR) << "Failed to convert Context to Lite Context";
-    return status;
+    return kLiteNullptr;
   }

   auto create_callback = CreateTrainSessionCallbackHolder();
   if (create_callback != nullptr) {
-    auto session = create_callback(graph_->graph_data_, cfg_, &model_context);
+    auto session = create_callback(graph_->graph_data_, cfg_, inner_context);
     if (session != nullptr) {
       session_ = session;
       MS_LOG(DEBUG) << "Build model success.";
@@ -113,10 +107,11 @@ Status ModelImpl::Build() {

   auto model = graph_->graph_data_->lite_model();
   if (model == nullptr || model->buf == nullptr) {
+    delete inner_context;
     MS_LOG(ERROR) << "Lite model has been freed.";
     return kLiteError;
   }
-  auto session = std::shared_ptr<session::LiteSession>(session::LiteSession::CreateSession(&model_context));
+  auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(inner_context));
   if (session == nullptr) {
     MS_LOG(ERROR) << "Allocate session failed.";
     return kLiteNullptr;
@@ -436,4 +431,21 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
   auto ret = session_->Resize(inner_input, truncated_shape);
   return static_cast<StatusCode>(ret);
 }

+lite::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
+  auto session = new (std::nothrow) lite::LiteSession();
+  if (session == nullptr) {
+    MS_LOG(ERROR) << "create session failed";
+    delete context;
+    return nullptr;
+  }
+
+  auto ret = session->Init(context);
+  if (ret != mindspore::lite::RET_OK) {
+    MS_LOG(ERROR) << "init session failed";
+    delete session;
+    return nullptr;
+  }
+  return session;
+}
 }  // namespace mindspore
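Note: `ModelImpl::CreateLiteSession` takes the `InnerContext *` by ownership. It deletes the context itself when the session allocation fails; on an `Init` failure it deletes only the session, which by then presumably holds the context. The `delete inner_context` added in `Build()` covers the one path that bails out before the handoff. A hedged condensed view of the flow (comments are editorial, not from the source):

```cpp
// Ownership flow in ModelImpl::Build (sketch, not verbatim source):
auto *inner_context = ContextUtils::Convert(context_.get());  // caller owns the raw pointer
if (inner_context == nullptr) return kLiteNullptr;            // nothing was allocated
auto session = std::shared_ptr<lite::LiteSession>(CreateLiteSession(inner_context));
// On allocation failure CreateLiteSession already deleted inner_context; on
// Init failure it deleted the session, which had taken the context over.
if (session == nullptr) return kLiteNullptr;
```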
@@ -29,6 +29,8 @@
 #include "include/api/cell.h"
 #include "include/lite_session.h"
 #include "src/cxx_api/graph/graph_data.h"
+#include "src/inner_context.h"
+#include "src/lite_session.h"

 template <class T>
 void clearVectorOfPointers(std::vector<T> *v) {
@@ -42,9 +44,9 @@ void clearVectorOfPointers(std::vector<T> *v) {

 namespace mindspore {

-typedef std::shared_ptr<session::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
-                                                                       std::shared_ptr<TrainCfg> cfg,
-                                                                       lite::Context *context);
+typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
+                                                                    std::shared_ptr<TrainCfg> cfg,
+                                                                    lite::InnerContext *context);
 CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);

 namespace session {
@@ -65,6 +67,7 @@ class ModelImpl {

   Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
                  const MSKernelCallBack &after);
+  lite::LiteSession *CreateLiteSession(lite::InnerContext *context);

   std::vector<MSTensor> GetInputs();
   std::vector<MSTensor> GetOutputs();
@@ -81,6 +84,7 @@ class ModelImpl {
     return kSuccess;
   }
   std::vector<Metrics *> GetMetrics() { return metrics_; }
+  const session::LiteSession *GetSession() const { return session_.get(); }

  protected:
   // Utility methods
@@ -94,7 +98,7 @@ class ModelImpl {
   friend class Model;
   friend class Serialization;
   std::shared_ptr<Graph> graph_ = nullptr;
-  std::shared_ptr<session::LiteSession> session_ = nullptr;
+  std::shared_ptr<lite::LiteSession> session_ = nullptr;
   std::shared_ptr<Context> context_ = nullptr;
   std::shared_ptr<TrainCfg> cfg_ = nullptr;
   std::vector<Metrics *> metrics_;
@@ -25,6 +25,7 @@
 #include "include/api/callback/callback.h"
 #include "include/api/metrics/metrics.h"
 #include "src/lite_model.h"
+#include "src/inner_context.h"
 #include "src/runtime/inner_allocator.h"
 #include "src/common/string_util.h"
 #include "src/cxx_api/model/model_impl.h"
@@ -40,8 +41,8 @@
 #include "src/train/train_session.h"

 namespace mindspore {
-std::shared_ptr<session::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
-                                                         std::shared_ptr<TrainCfg> cfg, lite::Context *context) {
+std::shared_ptr<lite::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
+                                                      std::shared_ptr<TrainCfg> cfg, lite::InnerContext *context) {
   bool is_train_session = graph_data->IsTrainModel();
   if (is_train_session) {
     auto model = graph_data->lite_model();
@@ -49,8 +50,8 @@ std::shared_ptr<session::LiteSession> CreateTrainSession(std::shared_ptr<Graph::
       MS_LOG(ERROR) << "Lite model has been freed.";
       return nullptr;
     }
-    std::shared_ptr<session::LiteSession> shared_session;
-    lite::TrainSession *session = new lite::TrainSession();
+    std::shared_ptr<lite::LiteSession> shared_session;
+    auto session = new (std::nothrow) lite::TrainSession();
     if (session == nullptr) {
       MS_LOG(ERROR) << "create session failed";
       return nullptr;
@@ -17,7 +17,6 @@
 #include <algorithm>
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
-#include "src/common/utils.h"
 #ifdef SUPPORT_NPU
 #include "include/HiAiModelManagerType.h"
 #endif
@@ -28,6 +27,8 @@
 namespace mindspore::lite {
 namespace {
 constexpr int kDefaultParallelNum = 2;
+constexpr int kMaxLiteContextDeviceNums = 2;
+constexpr int kMaxInnerContextDeviceNums = 3;
 }  // namespace

 InnerContext::InnerContext(const Context *context) {
@@ -46,24 +47,54 @@ InnerContext::InnerContext(const Context *context) {
 }

 void InnerContext::SetContextDevice(const Context *context) {
+  this->device_list_.clear();
+
+  if (context->device_list_.size() > kMaxLiteContextDeviceNums || context->device_list_.size() <= 0) {
+    return;
+  }
+  if (context->device_list_.front().device_type_ != DT_CPU) {
+    return;
+  }
+
+  /* user set order for different device */
+  if (context->device_list_.size() < kMaxLiteContextDeviceNums) {
+    this->device_list_.push_back(context->device_list_.front());
+    return;
+  }
+
+  /* keep compatibility :
+   * if user set CPU & NPU/GPU
+   * NPU/GPU higher priority */
   bool isUserSetNPU = context->device_list_.end() !=
-                      std::find_if(this->device_list_.begin(), this->device_list_.end(),
+                      std::find_if(context->device_list_.begin(), context->device_list_.end(),
                                    [](const DeviceContext &device) { return device.device_type_ == DT_NPU; });
   bool isUserSetGPU = context->device_list_.end() !=
-                      std::find_if(this->device_list_.begin(), this->device_list_.end(),
+                      std::find_if(context->device_list_.begin(), context->device_list_.end(),
                                    [](const DeviceContext &device) { return device.device_type_ == DT_GPU; });
-  this->device_list_.clear();
+  if (isUserSetGPU == false && isUserSetNPU == false) {
+    return;
+  }
+
+  /* add GPU/NPU first */
   for (auto &device_ctx : context->device_list_) {
-    // npu/gpu server would use one core so we don't bind core to avoid competition.
-    // If user does not set npu/gpu device, we still bind core.
-    if (device_ctx.device_type_ == DT_CPU && (isUserSetNPU || (isUserSetGPU && !enable_parallel_))) {
-      auto cpu_ctx = device_ctx;
-      cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
-      this->device_list_.push_back(cpu_ctx);
-    } else {
+    if (device_ctx.device_type_ != DT_CPU) {
       this->device_list_.push_back(device_ctx);
     }
   }
+
+  /* add CPU */
+  for (auto &device_ctx : context->device_list_) {
+    if (device_ctx.device_type_ == DT_CPU) {
+      if (isUserSetNPU || (isUserSetGPU && enable_parallel_ == false)) {
+        auto cpu_ctx = device_ctx;
+        cpu_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND;
+        this->device_list_.push_back(cpu_ctx);
+      } else {
+        this->device_list_.push_back(device_ctx);
+      }
+    }
+  }
+  return;
 }

 int InnerContext::Init() {
@@ -71,12 +102,11 @@ int InnerContext::Init() {
     MS_LOG(ERROR) << "Context is not valid";
     return RET_NOT_SUPPORT;
   }
-  if (this->thread_pool_ == nullptr && this->IsCpuEnabled()) {
+  if (this->thread_pool_ == nullptr) {
     int actor_parallel_thread = this->enable_parallel_ ? kDefaultParallelNum : 1;

     if (this->affinity_core_list_.empty()) {
-      auto bind_mode = static_cast<BindMode>(this->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_);
-      thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_, bind_mode);
+      thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_,
+                                                       static_cast<BindMode>(GetCpuBindMode()));
       if (thread_pool_ == nullptr) {
         MS_LOG(ERROR) << "Create ThreadPool failed";
         return RET_NULL_PTR;
@@ -131,8 +161,8 @@ int InnerContext::IsValid() const {
     MS_LOG(ERROR) << "Device list is empty.";
     return RET_NOT_SUPPORT;
   }
-  if (this->device_list_.size() > kMaxDeviceNums) {
-    MS_LOG(ERROR) << "Not support device list more than 2.";
+  if (this->device_list_.size() > kMaxInnerContextDeviceNums) {
+    MS_LOG(ERROR) << "Not support device list more than " << kMaxInnerContextDeviceNums;
     return RET_NOT_SUPPORT;
   }
   if (thread_num_ < 1) {
@@ -144,11 +174,6 @@ int InnerContext::IsValid() const {
     return RET_NOT_SUPPORT;
   }

-  if (!IsUserSetCpu()) {
-    MS_LOG(ERROR) << "CPU context should be set.";
-    return RET_NOT_SUPPORT;
-  }
-
   if (IsCpuBindModeInvalid()) {
     MS_LOG(ERROR) << "CPU bind mode should be one of NO_BIND, HIGHER_CPU or MID_CPU.";
     return RET_NOT_SUPPORT;
@@ -196,6 +221,13 @@ bool InnerContext::IsGpuFloat16Enabled() const {

 bool InnerContext::IsCpuEnabled() const { return IsUserSetCpu(); }

+int InnerContext::GetCpuBindMode() const {
+  auto iter = std::find_if(device_list_.begin(), device_list_.end(), [](const DeviceContext &device) {
+    return (device.device_type_ == DeviceType::DT_CPU) ? true : false;
+  });
+  return iter != device_list_.end() ? iter->device_info_.cpu_device_info_.cpu_bind_mode_ : NO_BIND;
+}
+
 bool InnerContext::IsGpuEnabled() const {
 #ifdef SUPPORT_GPU
   return IsUserSetGpu();
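Note: this is the heart of the "fix device priority" change. `SetContextDevice` now searches the user's `context->device_list_` (rather than the not-yet-populated internal list) and rebuilds the internal list with GPU/NPU entries ahead of the CPU entry, unbinding the CPU when an NPU, or a GPU without parallel mode, is present. A sketch of the observable effect (the `DeviceInfo` payloads `cpu_info`/`npu_info` are placeholders you would fill in):

```cpp
// Assuming a user context ordered [CPU, NPU]:
mindspore::lite::Context user_ctx;
user_ctx.device_list_.push_back({mindspore::lite::DT_CPU, cpu_info});  // cpu_info: placeholder
user_ctx.device_list_.push_back({mindspore::lite::DT_NPU, npu_info});  // npu_info: placeholder
mindspore::lite::InnerContext inner(&user_ctx);
// inner.device_list_ is now [NPU, CPU], and the CPU entry's cpu_bind_mode_ is
// forced to NO_BIND so inference threads do not compete with the NPU for cores.
```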
@@ -26,7 +26,6 @@
 #endif

 namespace mindspore::lite {
-const constexpr int kMaxDeviceNums = 2;
 struct InnerContext : public Context {
  public:
  InnerContext() = default;
@@ -41,6 +40,8 @@ struct InnerContext : public Context {

   bool IsCpuEnabled() const;

+  int GetCpuBindMode() const;
+
   bool IsGpuEnabled() const;

   bool IsNpuEnabled() const;
@@ -82,7 +83,6 @@ struct InnerContext : public Context {
 };

 int ParallelLaunch(const Context *context, const Func &func, Content content, int task_num);

 }  // namespace mindspore::lite

 #endif  // MINDSPORE_LITE_SRC_INNER_CONTEXT_H
@@ -655,7 +655,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af
   return ret;
 }

-int LiteSession::Init(const Context *context) {
+int LiteSession::Init(InnerContext *context) {
   bool expected = false;
   if (!is_running_.compare_exchange_strong(expected, true)) {
     MS_LOG(ERROR) << "Not support multi-threading";
@@ -666,20 +666,22 @@ int LiteSession::Init(const Context *context) {
     is_running_.store(false);
     return RET_NULL_PTR;
   }
-  this->context_ = new (std::nothrow) InnerContext(context);
-  if (this->context_ == nullptr) {
-    MS_LOG(ERROR) << "New Context failed";
-    is_running_.store(false);
-    return RET_MEMORY_FAILED;
-  }
+  this->context_ = context;
   auto ret = this->context_->Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init Context failed";
     is_running_.store(false);
     return ret;
   }
-  if (context->delegate != nullptr) {
-    delegate_ = context->delegate;
+  ms_context_ = MSContextFromContext(context_);
+  if (ms_context_ == nullptr) {
+    MS_LOG(ERROR) << "transfer context to ms context failed.";
+    is_running_.store(false);
+    return RET_NULL_PTR;
+  }
+
+  if (context_->delegate != nullptr) {
+    delegate_ = context_->delegate;
   }
 #if SUPPORT_NPU
   if (delegate_ == nullptr && context_->IsNpuEnabled()) {
@@ -688,7 +690,6 @@ int LiteSession::Init(InnerContext *context) {
       MS_LOG(ERROR) << "New delegate_ failed";
       return RET_ERROR;
     }
-    this->context_->delegate = delegate_;
   }
 #endif
 #if GPU_TENSORRT
@@ -698,7 +699,6 @@ int LiteSession::Init(InnerContext *context) {
       MS_LOG(ERROR) << "New tensorrt delegate_ failed";
       return RET_ERROR;
     }
-    this->context_->delegate = delegate_;
   }
 #endif
   if (delegate_ != nullptr) {
@@ -720,12 +720,6 @@ int LiteSession::Init(InnerContext *context) {
     is_running_.store(false);
     return ret;
   }
-  ms_context_ = MSContextFromContext(context);
-  if (ms_context_ == nullptr) {
-    MS_LOG(ERROR) << "transfer context to ms context failed.";
-    is_running_.store(false);
-    return RET_NULL_PTR;
-  }
   is_running_.store(false);
   return RET_OK;
 }
@@ -921,14 +915,15 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
 }

 int LiteSession::InitGPURuntime() {
-  CpuBindMode cpu_bind_mode = this->context_->device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_;
-  ActorThreadPool *thread_pool = this->context_->thread_pool();
-  if (thread_pool == nullptr) {
-    MS_LOG(ERROR) << "thread pool is nullptr";
-    is_running_.store(false);
-    return RET_NULL_PTR;
+  if (context_->IsCpuEnabled()) {
+    ActorThreadPool *thread_pool = this->context_->thread_pool();
+    if (thread_pool == nullptr) {
+      MS_LOG(ERROR) << "thread pool is nullptr";
+      is_running_.store(false);
+      return RET_NULL_PTR;
+    }
+    thread_pool->SetProcessAffinity(static_cast<BindMode>(static_cast<BindMode>(context_->GetCpuBindMode())));
   }
-  thread_pool->SetProcessAffinity(static_cast<BindMode>(cpu_bind_mode));
 #if GPU_OPENCL
   if (this->context_->IsGpuEnabled()) {
     opencl_runtime_wrapper_ = new (std::nothrow) opencl::OpenCLRuntimeWrapper();
@@ -971,7 +966,10 @@ int LiteSession::InitGPURuntime() {
   }
 #endif
   // Setting the binding core will affect the opencl drive scheduling.
-  thread_pool->SetProcessAffinity(static_cast<BindMode>(NO_BIND));
+  if (context_->IsCpuEnabled()) {
+    ActorThreadPool *thread_pool = this->context_->thread_pool();
+    thread_pool->SetProcessAffinity(static_cast<BindMode>(NO_BIND));
+  }
   return RET_OK;
 }
 }  // namespace lite
@@ -982,7 +980,7 @@ session::LiteSession *session::LiteSession::CreateSession(const lite::Context *c
     MS_LOG(ERROR) << "create session failed";
     return nullptr;
   }
-  auto ret = session->Init(context);
+  auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context));
   if (ret != mindspore::lite::RET_OK) {
     MS_LOG(ERROR) << "init session failed";
     delete session;
@@ -998,52 +996,67 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
     MS_LOG(ERROR) << "Create session failed";
     return nullptr;
   }
-  auto *model = lite::ImportFromBuffer(model_buf, size, true);
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "Import model failed";
-    return nullptr;
-  }
-  auto ret = session->CompileGraph(model);
-  if (ret != lite::RET_OK) {
-    MS_LOG(ERROR) << "Compile model failed";
-    model->buf = nullptr;
-    delete model;
-    delete session;
-    return nullptr;
-  }
-  model->buf = nullptr;
-  (reinterpret_cast<lite::LiteSession *>(session))->set_model(model);
+  auto ret = reinterpret_cast<lite::LiteSession *>(session)->LoadModelAndCompileByBuf(model_buf, size);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init session failed";
+    delete session;
+    return nullptr;
+  }
   return session;
 }

 session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_path, const lite::Context *context) {
-  size_t model_size;
-  auto model_buf = lite::ReadFile(model_path.c_str(), &model_size);
-  if (model_buf == nullptr) {
-    MS_LOG(ERROR) << "Read model file failed";
-    return nullptr;
-  }
   auto *session = session::LiteSession::CreateSession(context);
   if (session == nullptr) {
     MS_LOG(ERROR) << "Create session failed";
     return nullptr;
   }
-  auto *model = lite::ImportFromBuffer(model_buf, model_size, true);
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "Import model failed";
-    return nullptr;
-  }
-  (reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
-  auto ret = session->CompileGraph(model);
-  if (ret != lite::RET_OK) {
-    delete model;
-    delete session;
-    MS_LOG(ERROR) << "Compile model failed";
-    return nullptr;
-  }
-  (reinterpret_cast<lite::LiteSession *>(session))->set_model(model);
+  auto ret = reinterpret_cast<lite::LiteSession *>(session)->LoadModelAndCompileByPath(model_path);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init session failed";
+    delete session;
+    return nullptr;
+  }
   return session;
 }

+int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, size_t buf_size) {
+  auto *model = lite::ImportFromBuffer(model_buf, buf_size, true);
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "Import model failed";
+    return RET_ERROR;
+  }
+  auto ret = CompileGraph(model);
+  model->buf = nullptr;
+  if (ret != lite::RET_OK) {
+    MS_LOG(ERROR) << "Compile model failed";
+    delete model;
+    return RET_ERROR;
+  }
+  set_model(model);
+  return RET_OK;
+}
+
+int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path) {
+  size_t model_size;
+  auto model_buf = lite::ReadFile(model_path.c_str(), &model_size);
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "Read model file failed";
+    return RET_ERROR;
+  }
+  auto *model = lite::ImportFromBuffer(model_buf, model_size, true);
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "Import model failed";
+    return RET_ERROR;
+  }
+  (reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
+  auto ret = CompileGraph(model);
+  if (ret != lite::RET_OK) {
+    MS_LOG(ERROR) << "Compile model failed";
+    delete model;
+    return RET_ERROR;
+  }
+  set_model(model);
+  return RET_OK;
+}
 }  // namespace mindspore
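Note: both public factories now funnel through `Init(new InnerContext(...))` plus one of the `LoadModelAndCompileBy*` helpers, so a caller can also run the two steps explicitly. A hedged usage sketch ("model.ms" is a made-up path; error handling is abbreviated):

```cpp
// Sketch: explicit two-step construction equivalent to CreateSession(path, ctx).
// `lite_ctx` is assumed to be a configured mindspore::lite::Context.
auto *base = mindspore::session::LiteSession::CreateSession(&lite_ctx);
auto *session = reinterpret_cast<mindspore::lite::LiteSession *>(base);
if (session == nullptr ||
    session->LoadModelAndCompileByPath("model.ms") != mindspore::lite::RET_OK) {
  delete session;  // deleting nullptr is a no-op
  // handle the error
}
```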
@@ -49,7 +49,10 @@ class LiteSession : public session::LiteSession {

   static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);

-  virtual int Init(const Context *context);
+  int LoadModelAndCompileByBuf(const char *model_buf, size_t buf_size);
+  int LoadModelAndCompileByPath(const std::string &model_path);
+
+  virtual int Init(InnerContext *context);

   void BindThread(bool if_bind) override;

@@ -33,6 +33,7 @@
 #include "src/common/version_manager.h"
 #include "src/common/prim_util.h"
 #include "src/common/tensor_util.h"
+#include "src/common/context_util.h"
 #include "src/runtime/infer_manager.h"
 #include "src/sub_graph_split.h"
 #include "src/weight_decoder.h"
@@ -234,6 +235,19 @@ int Scheduler::InitKernels(std::vector<kernel::LiteKernel *> dst_kernels) {
   return RET_OK;
 }

+int Scheduler::CheckCpuValid(const std::vector<kernel::LiteKernel *> *dst_kernels) const {
+  if (context_->IsCpuEnabled()) {
+    return RET_OK;
+  }
+  for (auto kernel : *dst_kernels) {
+    if (kernel->desc().arch == kernel::KERNEL_ARCH::kCPU) {
+      MS_LOG(ERROR) << "kernel: " << kernel->name() << " only support in CPU.";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+
 int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
   if (dst_kernels == nullptr) {
     return RET_ERROR;
@@ -270,16 +284,19 @@ int Scheduler::Schedule(std::vector<kernel::LiteKernel *> *dst_kernels) {
     MS_LOG(ERROR) << "Schedule graph to kernels failed.";
     return ret;
   }

   SetSubgraphForPartialNode();
-  if (delegate_ != nullptr) {
-    ret = ReplaceDelegateKernels(dst_kernels);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Repalce delegate kernels failed.";
-      return ret;
-    }
-    context_->thread_pool()->SetMaxSpinCount(kMinSpinCount);
+
+  ret = InitDelegateKernels(dst_kernels);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Repalce delegate kernels failed.";
+    return ret;
+  }
+  ret = CheckCpuValid(dst_kernels);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "kernels invalid in set devices.";
+    return ret;
   }

   FindAllInoutKernels(*dst_kernels);

   if (IsControlFlowParttern(*dst_kernels)) {
@@ -378,6 +395,55 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_ker
   return RET_OK;
 }

+int Scheduler::InitDelegateKernels(std::vector<kernel::LiteKernel *> *dst_kernels) {
+  /* no delegate valid */
+  if (delegate_ == nullptr) {
+    return RET_OK;
+  }
+  /* set delegate spin count */
+  context_->thread_pool()->SetMaxSpinCount(kMinSpinCount);
+  /* Inner delegate : check Priority */
+  std::vector<kernel::LiteKernel *> src_kernels = *dst_kernels;
+  dst_kernels->clear();
+
+  while (!src_kernels.empty()) {
+    std::vector<kernel::LiteKernel *> tmp_kernels;
+    kernel::LiteKernel *remain_kernel = nullptr;
+
+    /* Loop for inner delegate npu and TensorRT subgraph */
+    while (!src_kernels.empty()) {
+      auto kernel = src_kernels.front();
+      VectorErase(&src_kernels, kernel);
+      bool priority_ret = DeviceTypePriority(context_, DT_NPU, KernelArchToDeviceType(kernel->desc().arch));
+      if (priority_ret == true) {
+        tmp_kernels.push_back(kernel);
+      } else {
+        remain_kernel = kernel;
+        break;
+      }
+    }
+
+    /* start current NPU-kernels replace */
+    if (tmp_kernels.empty()) {
+      if (remain_kernel != nullptr) {
+        dst_kernels->push_back(remain_kernel);
+        remain_kernel = nullptr;
+      }
+      continue;
+    }
+    auto ret = ReplaceDelegateKernels(&tmp_kernels);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "NPU delegate repalce delegate kernels failed.";
+      return ret;
+    }
+    dst_kernels->insert(dst_kernels->end(), tmp_kernels.begin(), tmp_kernels.end());
+    tmp_kernels.clear();
+    if (remain_kernel != nullptr) {
+      dst_kernels->push_back(remain_kernel);
+      remain_kernel = nullptr;
+    }
+  }
+  return RET_OK;
+}
+
 void Scheduler::FindNodeInoutTensors(const lite::Model::Node &node, std::vector<Tensor *> *inputs,
                                      std::vector<Tensor *> *outputs) {
   MS_ASSERT(inputs != nullptr);
@@ -914,7 +980,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
   op_parameter->is_train_session_ = is_train_session_;
   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
 #ifdef GPU_OPENCL
-  if (node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType) {
+  bool gpu_priority = DeviceTypePriority(context_, DT_GPU, DT_CPU);
+  bool use_gpu_kernel = node->device_type_ == DT_GPU || node->device_type_ == kDefaultDeviceType;
+  if (gpu_priority && use_gpu_kernel) {
     status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
     if (status == RET_OK) {
       return kernel;
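Note: `InitDelegateKernels` walks the scheduled list once, greedily collecting maximal runs of kernels whose own backend is outranked (or matched) by the NPU in the device list, handing each run to the delegate via `ReplaceDelegateKernels`, and passing the first non-qualifying kernel through unchanged, so overall order is preserved. A standalone, assumption-free toy analogue of that partitioning loop:

```cpp
#include <functional>
#include <vector>

// Split a sequence into maximal runs satisfying `pred`, transform each run in
// place, and keep non-matching elements in their original relative positions.
std::vector<int> PartitionAndReplace(std::vector<int> src, const std::function<bool(int)> &pred,
                                     const std::function<void(std::vector<int> *)> &replace) {
  std::vector<int> dst;
  while (!src.empty()) {
    std::vector<int> run;
    int remain = 0;
    bool has_remain = false;
    while (!src.empty()) {                 // mirrors the inner kernel loop
      int v = src.front();
      src.erase(src.begin());              // mirrors VectorErase
      if (pred(v)) {
        run.push_back(v);                  // qualifies: goes to the delegate run
      } else {
        remain = v;                        // first non-qualifying element
        has_remain = true;
        break;
      }
    }
    if (!run.empty()) {
      replace(&run);                       // mirrors ReplaceDelegateKernels
      dst.insert(dst.end(), run.begin(), run.end());
    }
    if (has_remain) dst.push_back(remain); // pass-through, order preserved
  }
  return dst;
}
```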
@@ -69,12 +69,15 @@ class Scheduler {
   int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                     OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
                     kernel::LiteKernel **kernel);
+  int CheckCpuValid(const std::vector<kernel::LiteKernel *> *dst_kernels) const;
   int FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                     OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
   int FindProviderKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                          const Model::Node *node, TypeId data_type, kernel::LiteKernel **kernel);

   int ReplaceDelegateKernels(std::vector<kernel::LiteKernel *> *dst_kernels);
+  int InitDelegateKernels(std::vector<kernel::LiteKernel *> *dst_kernels);
+
   int InitKernels(std::vector<kernel::LiteKernel *> dst_kernels);
   kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node);
   // schedule a partial node to a subgraph_kernel
@@ -51,7 +51,7 @@ TrainSession::TrainSession() {
   InitCallBack();
 }

-int TrainSession::Init(const Context *context, const TrainCfg *train_cfg) {
+int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) {
   if (train_cfg != nullptr) {
     if (train_cfg->mix_precision_cfg_.loss_scale_ <= 0) {
       MS_LOG(ERROR) << "illegal loss scale configuration";
@@ -769,7 +769,7 @@ session::LiteSession *session::TrainSession::CreateTrainSession(const std::strin
     return nullptr;
   }

-  auto ret = session->Init(context, cfg);
+  auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context), cfg);
   if (ret != mindspore::lite::RET_OK) {
     MS_LOG(ERROR) << "init session failed";
     return nullptr;
@@ -54,7 +54,7 @@ class TrainSession : virtual public lite::LiteSession {
   int CompileGraph(lite::Model *model) override;
   virtual int CompileTrainGraph(std::shared_ptr<Model> model);

-  virtual int Init(const Context *context, const TrainCfg *train_cfg);
+  virtual int Init(InnerContext *context, const TrainCfg *train_cfg);

   int Train() override;
   int Eval() override;
@@ -248,7 +248,7 @@ static session::LiteSession *CreateTransferSessionInt(const char *model_buf_back
     return nullptr;
   }

-  auto ret = session->Init(context, cfg);
+  auto ret = session->Init(new (std::nothrow) mindspore::lite::InnerContext(context), cfg);
   if (ret != lite::RET_OK) {
     MS_LOG(ERROR) << "init transfer session failed";
     delete session;
@@ -97,21 +97,20 @@ int BenchmarkCApi::InitContext() {
   MSContextSetEnableParallel(context_, flags_->enable_parallel_);
   MSContextSetThreadAffinityMode(context_, flags_->cpu_bind_mode_);

-  MSDeviceInfoHandle cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU);
-  MSDeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_);
-  MSContextAddDeviceInfo(context_, cpu_device_info);
-
   if (flags_->device_ == "GPU") {
     MSDeviceInfoHandle gpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeGPU);
     MSDeviceInfoSetEnableFP16(gpu_device_info, flags_->enable_fp16_);
     MSContextAddDeviceInfo(context_, gpu_device_info);
   }

   if (flags_->device_ == "NPU") {
     MSDeviceInfoHandle npu_device_info = MSDeviceInfoCreate(kMSDeviceTypeKirinNPU);
     MSDeviceInfoSetFrequency(npu_device_info, kFrequencyDefault);
     MSContextAddDeviceInfo(context_, npu_device_info);
   }

+  MSDeviceInfoHandle cpu_device_info = MSDeviceInfoCreate(kMSDeviceTypeCPU);
+  MSDeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_);
+  MSContextAddDeviceInfo(context_, cpu_device_info);
   return RET_OK;
 }

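Note: registering the CPU after the GPU/NPU device is what makes the benchmarks exercise the new positional priority: the converted device list is probed front-to-back, so the accelerator added first becomes the preferred backend and the CPU the fallback. A minimal sketch of building such a context by hand (handle names are illustrative; `MSContextCreate` and the other calls are the C API used above, and `kFrequencyDefault` is the benchmark-local constant from the diff):

```cpp
// Sketch: device order expresses priority in the C API after this change.
MSContextHandle ctx = MSContextCreate();
MSDeviceInfoHandle npu = MSDeviceInfoCreate(kMSDeviceTypeKirinNPU);
MSDeviceInfoSetFrequency(npu, kFrequencyDefault);
MSContextAddDeviceInfo(ctx, npu);  // added first -> preferred backend
MSDeviceInfoHandle cpu = MSDeviceInfoCreate(kMSDeviceTypeCPU);
MSContextAddDeviceInfo(ctx, cpu);  // added last -> fallback
```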
@@ -127,10 +127,6 @@ void BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context
   context->SetThreadAffinity(flags_->cpu_bind_mode_);
   auto &device_list = context->MutableDeviceInfo();

-  std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
-  device_info->SetEnableFP16(flags_->enable_fp16_);
-  device_list.push_back(device_info);
-
   if (flags_->device_ == "GPU") {
     std::shared_ptr<GPUDeviceInfo> gpu_device_info = std::make_shared<GPUDeviceInfo>();
     gpu_device_info->SetEnableFP16(flags_->enable_fp16_);
@@ -142,6 +138,11 @@ void BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context
     npu_device_info->SetFrequency(kFrequencyDefault);
     device_list.push_back(npu_device_info);
   }
+
+  // CPU priority is behind GPU and NPU
+  std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
+  device_info->SetEnableFP16(flags_->enable_fp16_);
+  device_list.push_back(device_info);
 }

 int BenchmarkUnifiedApi::CompareOutput() {