forked from mindspore-Ecosystem/mindspore
fix the bprob error of embeddinglookup
This commit is contained in:
parent
373832d030
commit
69574f3823
|
@ -203,7 +203,7 @@ def get_bprop_embedding_lookup(self):
|
|||
actual_dout = elu_grad(dout, split_num)
|
||||
else:
|
||||
actual_dout = dout
|
||||
new_indices = host_sub(indices - offset)
|
||||
new_indices = host_sub(indices, offset)
|
||||
# Reshape the 'new_indices'
|
||||
new_indices_shape_changed = (size_op(new_indices),)
|
||||
new_indices = host_reshape(new_indices, new_indices_shape_changed)
|
||||
|
@ -211,7 +211,7 @@ def get_bprop_embedding_lookup(self):
|
|||
x_shp_tail = x_shp[1:]
|
||||
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
|
||||
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
|
||||
return (new_indices, actual_dout, x_shp), zeros_like(new_indices), zeros_like(axis), \
|
||||
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
|
||||
zeros_like(reduce_scatter_flag), zeros_like(split_num)
|
||||
return bprop_sparse
|
||||
|
||||
|
|
|
@ -16,12 +16,20 @@ import numpy as np
|
|||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore import Tensor, context
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
return C.grad_all(self.network)(x, y)
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
|
@ -73,3 +81,30 @@ def test_embeddinglookup_reducescatter_true():
|
|||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_embeddinglookup_reducescatter_false_grad():
|
||||
shape = [8, 8]
|
||||
offset = 8
|
||||
reduce_scatter_flag = False
|
||||
split_num = 1
|
||||
net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_embeddinglookup_reducescatter_true_grad():
|
||||
context.set_context(save_graphs=True)
|
||||
shape = [64, 8]
|
||||
offset = 8
|
||||
reduce_scatter_flag = True
|
||||
split_num = 8
|
||||
net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
||||
_executor.compile(net, x, y)
|
||||
|
|
Loading…
Reference in New Issue