fix embeddinglookup bug

This commit is contained in:
lichenever 2020-07-15 09:26:52 +08:00
parent bc0a53cfb1
commit 6dbb26967e
2 changed files with 9 additions and 3 deletions

View File

@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope);
replace_node->set_scope(scope);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
if (prim->name() == EMBEDDING_LOOKUP) {
auto attrs = prim->attrs();
attrs[TARGET] = MakeValue(CPU);
(void)prim->SetAttrs(attrs);
}
if (index == replace_op.size() - 1) {
(void)replace_node->set_operator_info(node->operator_info());
}

View File

@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
from ..cell import Cell
from ..._checkparam import Validator as validator
__all__ = ['Embedding']
__all__ = ['Embedding', 'EmbeddingLookup']
class Embedding(Cell):
r"""
@ -147,7 +147,7 @@ class EmbeddingLookup(Cell):
def construct(self, params, indices):
if self.target == "CPU":
out = self.embeddinglookup(params, ids, 0)
out = self.embeddinglookup(params, indices, 0)
else:
out = self.gatherv2(param, ids, 0)
out = self.gatherv2(params, indices, 0)
return out