model pool support ascend

This commit is contained in:
zhengyuanhua 2022-08-09 14:25:14 +08:00
parent 276654b9e4
commit 3b1a5e435c
2 changed files with 12 additions and 7 deletions

View File

@ -344,10 +344,7 @@ Status ModelPool::CheckAffinityCoreList(const std::shared_ptr<RunnerConfig> &run
std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<RunnerConfig> &runner_config) {
auto context = runner_config->GetContext();
if (context == nullptr) {
MS_LOG(ERROR) << "user set config context nullptr.";
return nullptr;
}
MS_CHECK_TRUE_MSG(context != nullptr, nullptr, "user set config context nullptr.");
auto user_thread_num = context->GetThreadNum();
if (user_thread_num < 0) {
MS_LOG(ERROR) << "Invalid thread num " << context->GetThreadNum();
@ -374,8 +371,8 @@ std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<R
}
for (size_t i = 0; i < device_list.size(); i++) {
auto device = device_list[i];
if (device->GetDeviceType() != kCPU && device->GetDeviceType() != kGPU) {
MS_LOG(ERROR) << "model pool only support cpu or gpu type.";
if (device->GetDeviceType() != kCPU && device->GetDeviceType() != kGPU && device->GetDeviceType() != kAscend) {
MS_LOG(ERROR) << "model pool only support cpu or gpu or ascend type.";
return nullptr;
}
if (device->GetDeviceType() == kGPU && device_list.size() == kNumDeviceInfo) {
@ -395,6 +392,11 @@ std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<R
MS_LOG(ERROR) << "model pool not support enable fp16.";
return nullptr;
}
} else if (device->GetDeviceType() == kAscend) {
if (context->GetInterOpParallelNum() == 0) {
context->SetInterOpParallelNum(1); // do not use InterOpParallel
}
return context;
} else {
MS_LOG(ERROR) << "context is invalid; If you want run in GPU, you must set gpu device first, and then set cpu "
"device";

View File

@ -126,8 +126,11 @@ SessionConfig InferSession::SelectSessionArg(const std::shared_ptr<Context> &con
auto delegate_config = std::make_shared<mindspore::DelegateConfig>(context);
auto &device_contexts = context->MutableDeviceInfo();
for (auto device_context : device_contexts) {
// delegate init
MS_EXCEPTION_IF_NULL(device_context);
if (device_context->GetDeviceType() == kAscend) {
config.type_ = kSingleOpSession;
return config;
}
// get graph executor delegate
auto delegate = mindspore::DelegateRegistry::GetInstance().GetDelegate(
device_context->GetDeviceType(), device_context->GetProvider(), delegate_config);