From edbb5d960877d3bde266463de546565912e8e2a1 Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Sat, 20 Jun 2020 10:37:28 +0800 Subject: [PATCH] keep the Sub op in the grad of EmbeddingLookup on device --- mindspore/ops/_grad/grad_array_ops.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 923a8783b31..190bb4643bb 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -194,23 +194,26 @@ def get_bprop_tile(self): @bprop_getters.register(inner.EmbeddingLookup) def get_bprop_embedding_lookup(self): """Generate bprop for EmbeddingLookup""" - host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU') + sub_op = P.Sub() + reshape_op = P.Reshape() host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU') def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout): x_shp = shape_op(x) - if reduce_scatter_flag is True: - elu_grad = G.EmbeddingLookupCommGrad() - actual_dout = elu_grad(dout, split_num) - else: - actual_dout = dout - new_indices = host_sub(indices, offset) + new_indices = sub_op(indices, offset) # Reshape the 'new_indices' new_indices_shape_changed = (size_op(new_indices),) - new_indices = host_reshape(new_indices, new_indices_shape_changed) - # Reshape the 'actual_dout' + new_indices = reshape_op(new_indices, new_indices_shape_changed) x_shp_tail = x_shp[1:] actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail - actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) + if reduce_scatter_flag is True: + # On host + elu_grad = G.EmbeddingLookupCommGrad() + actual_dout = elu_grad(dout, split_num) + # Reshape the 'actual_dout' on host + actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) + else: + # Reshape the 'actual_dout' on device + actual_dout = reshape_op(dout, actual_dout_shape_changed) return (new_indices, actual_dout, 
x_shp), zeros_like(indices), zeros_like(offset), \ zeros_like(reduce_scatter_flag), zeros_like(split_num) return bprop_sparse