model pool support ascend
This commit is contained in:
parent
276654b9e4
commit
3b1a5e435c
|
@ -344,10 +344,7 @@ Status ModelPool::CheckAffinityCoreList(const std::shared_ptr<RunnerConfig> &run
|
|||
|
||||
std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<RunnerConfig> &runner_config) {
|
||||
auto context = runner_config->GetContext();
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "user set config context nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(context != nullptr, nullptr, "user set config context nullptr.");
|
||||
auto user_thread_num = context->GetThreadNum();
|
||||
if (user_thread_num < 0) {
|
||||
MS_LOG(ERROR) << "Invalid thread num " << context->GetThreadNum();
|
||||
|
@ -374,8 +371,8 @@ std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<R
|
|||
}
|
||||
for (size_t i = 0; i < device_list.size(); i++) {
|
||||
auto device = device_list[i];
|
||||
if (device->GetDeviceType() != kCPU && device->GetDeviceType() != kGPU) {
|
||||
MS_LOG(ERROR) << "model pool only support cpu or gpu type.";
|
||||
if (device->GetDeviceType() != kCPU && device->GetDeviceType() != kGPU && device->GetDeviceType() != kAscend) {
|
||||
MS_LOG(ERROR) << "model pool only support cpu or gpu or ascend type.";
|
||||
return nullptr;
|
||||
}
|
||||
if (device->GetDeviceType() == kGPU && device_list.size() == kNumDeviceInfo) {
|
||||
|
@ -395,6 +392,11 @@ std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<R
|
|||
MS_LOG(ERROR) << "model pool not support enable fp16.";
|
||||
return nullptr;
|
||||
}
|
||||
} else if (device->GetDeviceType() == kAscend) {
|
||||
if (context->GetInterOpParallelNum() == 0) {
|
||||
context->SetInterOpParallelNum(1); // do not use InterOpParallel
|
||||
}
|
||||
return context;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "context is invalid; If you want run in GPU, you must set gpu device first, and then set cpu "
|
||||
"device";
|
||||
|
|
|
@ -126,8 +126,11 @@ SessionConfig InferSession::SelectSessionArg(const std::shared_ptr<Context> &con
|
|||
auto delegate_config = std::make_shared<mindspore::DelegateConfig>(context);
|
||||
auto &device_contexts = context->MutableDeviceInfo();
|
||||
for (auto device_context : device_contexts) {
|
||||
// delegate init
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if (device_context->GetDeviceType() == kAscend) {
|
||||
config.type_ = kSingleOpSession;
|
||||
return config;
|
||||
}
|
||||
// get graph executor delegate
|
||||
auto delegate = mindspore::DelegateRegistry::GetInstance().GetDelegate(
|
||||
device_context->GetDeviceType(), device_context->GetProvider(), delegate_config);
|
||||
|
|
Loading…
Reference in New Issue