forked from mindspore-Ecosystem/mindspore
fix embeddinglookup error on cpu pynative mode
This commit is contained in:
parent
0c88e3f256
commit
98d6898438
|
@ -422,11 +422,14 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
|
||||||
MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name;
|
MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name;
|
||||||
reg_exist = false;
|
reg_exist = false;
|
||||||
}
|
}
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
|
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
|
||||||
|
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
|
||||||
reg_exist = false;
|
reg_exist = false;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if (op_run_info->op_name == prim::kPrimGatherD->name()) {
|
if (op_run_info->op_name == prim::kPrimGatherD->name()) {
|
||||||
auto ms_context = MsContext::GetInstance();
|
|
||||||
// Gather op needs converting const input to attr on GPU device
|
// Gather op needs converting const input to attr on GPU device
|
||||||
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
||||||
reg_exist = false;
|
reg_exist = false;
|
||||||
|
|
Loading…
Reference in New Issue