diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index e9ff347fa3d..6b9cfd9d370 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { ScopePtr scope = node->scope(); MS_EXCEPTION_IF_NULL(scope); replace_node->set_scope(scope); + PrimitivePtr prim = GetValueNode(replace_node->input(0)); + if (prim->name() == EMBEDDING_LOOKUP) { + auto attrs = prim->attrs(); + attrs[TARGET] = MakeValue(CPU); + (void)prim->SetAttrs(attrs); + } if (index == replace_op.size() - 1) { (void)replace_node->set_operator_info(node->operator_info()); } diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index a0887886a08..3c4245d7020 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer from ..cell import Cell from ..._checkparam import Validator as validator -__all__ = ['Embedding'] +__all__ = ['Embedding', 'EmbeddingLookup'] class Embedding(Cell): r""" @@ -147,7 +147,7 @@ class EmbeddingLookup(Cell): def construct(self, params, indices): if self.target == "CPU": - out = self.embeddinglookup(params, ids, 0) + out = self.embeddinglookup(params, indices, 0) else: - out = self.gatherv2(param, ids, 0) + out = self.gatherv2(params, indices, 0) return out