fix the bprob error of embeddinglookup

This commit is contained in:
Xiaoda Zhang 2020-06-18 12:01:42 +08:00
parent 373832d030
commit 69574f3823
2 changed files with 38 additions and 3 deletions

View File

@ -203,7 +203,7 @@ def get_bprop_embedding_lookup(self):
actual_dout = elu_grad(dout, split_num) actual_dout = elu_grad(dout, split_num)
else: else:
actual_dout = dout actual_dout = dout
new_indices = host_sub(indices - offset) new_indices = host_sub(indices, offset)
# Reshape the 'new_indices' # Reshape the 'new_indices'
new_indices_shape_changed = (size_op(new_indices),) new_indices_shape_changed = (size_op(new_indices),)
new_indices = host_reshape(new_indices, new_indices_shape_changed) new_indices = host_reshape(new_indices, new_indices_shape_changed)
@ -211,7 +211,7 @@ def get_bprop_embedding_lookup(self):
x_shp_tail = x_shp[1:] x_shp_tail = x_shp[1:]
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
return (new_indices, actual_dout, x_shp), zeros_like(new_indices), zeros_like(axis), \ return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
zeros_like(reduce_scatter_flag), zeros_like(split_num) zeros_like(reduce_scatter_flag), zeros_like(split_num)
return bprop_sparse return bprop_sparse

View File

@ -16,12 +16,20 @@ import numpy as np
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _inner_ops as inner
from mindspore import Tensor, context
from tests.ut.python.ops.test_math_ops import VirtualLoss from tests.ut.python.ops.test_math_ops import VirtualLoss
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return C.grad_all(self.network)(x, y)
class NetWithLoss(nn.Cell): class NetWithLoss(nn.Cell):
def __init__(self, network): def __init__(self, network):
@ -73,3 +81,30 @@ def test_embeddinglookup_reducescatter_true():
x = Tensor(np.ones([64, 32]), dtype=ms.float32) x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32) y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
_executor.compile(net, x, y) _executor.compile(net, x, y)
def test_embeddinglookup_reducescatter_false_grad():
shape = [8, 8]
offset = 8
reduce_scatter_flag = False
split_num = 1
net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_embeddinglookup_reducescatter_true_grad():
context.set_context(save_graphs=True)
shape = [64, 8]
offset = 8
reduce_scatter_flag = True
split_num = 8
net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
_executor.compile(net, x, y)