forked from mindspore-Ecosystem/mindspore
fix embeddinglookup bug
This commit is contained in:
parent
bc0a53cfb1
commit
6dbb26967e
|
@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
|
|||
ScopePtr scope = node->scope();
|
||||
MS_EXCEPTION_IF_NULL(scope);
|
||||
replace_node->set_scope(scope);
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
|
||||
if (prim->name() == EMBEDDING_LOOKUP) {
|
||||
auto attrs = prim->attrs();
|
||||
attrs[TARGET] = MakeValue(CPU);
|
||||
(void)prim->SetAttrs(attrs);
|
||||
}
|
||||
if (index == replace_op.size() - 1) {
|
||||
(void)replace_node->set_operator_info(node->operator_info());
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
|
|||
from ..cell import Cell
|
||||
from ..._checkparam import Validator as validator
|
||||
|
||||
__all__ = ['Embedding']
|
||||
__all__ = ['Embedding', 'EmbeddingLookup']
|
||||
|
||||
class Embedding(Cell):
|
||||
r"""
|
||||
|
@ -147,7 +147,7 @@ class EmbeddingLookup(Cell):
|
|||
|
||||
def construct(self, params, indices):
|
||||
if self.target == "CPU":
|
||||
out = self.embeddinglookup(params, ids, 0)
|
||||
out = self.embeddinglookup(params, indices, 0)
|
||||
else:
|
||||
out = self.gatherv2(param, ids, 0)
|
||||
out = self.gatherv2(params, indices, 0)
|
||||
return out
|
||||
|
|
Loading…
Reference in New Issue