forked from mindspore-Ecosystem/mindspore
keep the sub of in grad of embeddinglookup on device
This commit is contained in:
parent
f3044f0034
commit
edbb5d9608
|
@ -194,23 +194,26 @@ def get_bprop_tile(self):
|
|||
@bprop_getters.register(inner.EmbeddingLookup)
|
||||
def get_bprop_embedding_lookup(self):
|
||||
"""Generate bprop for EmbeddingLookup"""
|
||||
host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU')
|
||||
sub_op = P.Sub()
|
||||
reshape_op = P.Reshape()
|
||||
host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU')
|
||||
def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout):
|
||||
x_shp = shape_op(x)
|
||||
if reduce_scatter_flag is True:
|
||||
elu_grad = G.EmbeddingLookupCommGrad()
|
||||
actual_dout = elu_grad(dout, split_num)
|
||||
else:
|
||||
actual_dout = dout
|
||||
new_indices = host_sub(indices, offset)
|
||||
new_indices = sub_op(indices, offset)
|
||||
# Reshape the 'new_indices'
|
||||
new_indices_shape_changed = (size_op(new_indices),)
|
||||
new_indices = host_reshape(new_indices, new_indices_shape_changed)
|
||||
# Reshape the 'actual_dout'
|
||||
new_indices = reshape_op(new_indices, new_indices_shape_changed)
|
||||
x_shp_tail = x_shp[1:]
|
||||
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
|
||||
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
|
||||
if reduce_scatter_flag is True:
|
||||
# On host
|
||||
elu_grad = G.EmbeddingLookupCommGrad()
|
||||
actual_dout = elu_grad(dout, split_num)
|
||||
# Reshape the 'actual_dout' on host
|
||||
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
|
||||
else:
|
||||
# Reshape the 'actual_dout' on device
|
||||
actual_dout = reshape_op(dout, actual_dout_shape_changed)
|
||||
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
|
||||
zeros_like(reduce_scatter_flag), zeros_like(split_num)
|
||||
return bprop_sparse
|
||||
|
|
Loading…
Reference in New Issue