!3068 [AutoParallel]Fix EmbeddingLookup bug

Merge pull request !3068 from lichen/fix_embeddinglookup
This commit is contained in:
mindspore-ci-bot 2020-07-15 14:24:49 +08:00 committed by Gitee
commit 7e5e868d97
2 changed files with 9 additions and 3 deletions

View File

@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope);
replace_node->set_scope(scope);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
if (prim->name() == EMBEDDING_LOOKUP) {
auto attrs = prim->attrs();
attrs[TARGET] = MakeValue(CPU);
(void)prim->SetAttrs(attrs);
}
if (index == replace_op.size() - 1) {
(void)replace_node->set_operator_info(node->operator_info());
}

View File

@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
from ..cell import Cell
from ..._checkparam import Validator as validator
__all__ = ['Embedding']
__all__ = ['Embedding', 'EmbeddingLookup']
class Embedding(Cell):
r"""
@ -147,7 +147,7 @@ class EmbeddingLookup(Cell):
def construct(self, params, indices):
if self.target == "CPU":
out = self.embeddinglookup(params, ids, 0)
out = self.embeddinglookup(params, indices, 0)
else:
out = self.gatherv2(param, ids, 0)
out = self.gatherv2(params, indices, 0)
return out