fix embeddinglookup error on cpu pynative mode

This commit is contained in:
chujinjin 2020-12-15 20:01:11 +08:00
parent 0c88e3f256
commit 98d6898438
1 changed files with 5 additions and 2 deletions

View File

@ -422,11 +422,14 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name;
reg_exist = false; reg_exist = false;
} }
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
reg_exist = false; if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
reg_exist = false;
}
} }
if (op_run_info->op_name == prim::kPrimGatherD->name()) { if (op_run_info->op_name == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
// Gather op needs converting const input to attr on GPU device // Gather op needs converting const input to attr on GPU device
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
reg_exist = false; reg_exist = false;