forked from mindspore-Ecosystem/mindspore
Fix the bprop error of EmbeddingLookup
This commit is contained in:
parent
373832d030
commit
69574f3823
|
@ -203,7 +203,7 @@ def get_bprop_embedding_lookup(self):
|
||||||
actual_dout = elu_grad(dout, split_num)
|
actual_dout = elu_grad(dout, split_num)
|
||||||
else:
|
else:
|
||||||
actual_dout = dout
|
actual_dout = dout
|
||||||
new_indices = host_sub(indices - offset)
|
new_indices = host_sub(indices, offset)
|
||||||
# Reshape the 'new_indices'
|
# Reshape the 'new_indices'
|
||||||
new_indices_shape_changed = (size_op(new_indices),)
|
new_indices_shape_changed = (size_op(new_indices),)
|
||||||
new_indices = host_reshape(new_indices, new_indices_shape_changed)
|
new_indices = host_reshape(new_indices, new_indices_shape_changed)
|
||||||
|
@ -211,7 +211,7 @@ def get_bprop_embedding_lookup(self):
|
||||||
x_shp_tail = x_shp[1:]
|
x_shp_tail = x_shp[1:]
|
||||||
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
|
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
|
||||||
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
|
actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
|
||||||
return (new_indices, actual_dout, x_shp), zeros_like(new_indices), zeros_like(axis), \
|
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
|
||||||
zeros_like(reduce_scatter_flag), zeros_like(split_num)
|
zeros_like(reduce_scatter_flag), zeros_like(split_num)
|
||||||
return bprop_sparse
|
return bprop_sparse
|
||||||
|
|
||||||
|
|
|
@ -16,12 +16,20 @@ import numpy as np
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
|
||||||
from mindspore.common.api import _executor
|
from mindspore.common.api import _executor
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops.operations import _inner_ops as inner
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
|
from mindspore import Tensor, context
|
||||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||||
|
|
||||||
|
class GradWrap(nn.Cell):
    """Wrap a network so that calling the cell evaluates `C.grad_all` on it.

    Args:
        network: The cell handed to `C.grad_all` inside `construct`.
    """

    def __init__(self, network):
        super().__init__()
        # Stored on the instance so that `construct` can reach the wrapped cell.
        self.network = network

    def construct(self, x, y):
        # Build the gradient function from the wrapped network and call it
        # with both inputs in a single expression.
        return C.grad_all(self.network)(x, y)
|
||||||
|
|
||||||
class NetWithLoss(nn.Cell):
|
class NetWithLoss(nn.Cell):
|
||||||
def __init__(self, network):
|
def __init__(self, network):
|
||||||
|
@ -73,3 +81,30 @@ def test_embeddinglookup_reducescatter_true():
|
||||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
|
||||||
_executor.compile(net, x, y)
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embeddinglookup_reducescatter_false_grad():
    """Compile the gradient-wrapped lookup net with reduce_scatter_flag=False."""
    # Positional order expected by Net: shape, offset, reduce_scatter_flag, split_num.
    net_args = ([8, 8], 8, False, 1)
    grad_net = GradWrap(NetWithLoss(Net(*net_args)))
    grad_net.set_auto_parallel()

    # All-ones float32 inputs; only graph compilation is exercised here.
    first_input = Tensor(np.ones([64, 32]), dtype=ms.float32)
    second_input = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
    _executor.compile(grad_net, first_input, second_input)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embeddinglookup_reducescatter_true_grad():
    """Compile the gradient-wrapped lookup net with reduce_scatter_flag=True."""
    # NOTE(review): this changes the global context for the process; presumably
    # a debugging aid -- confirm it is meant to stay enabled in this test.
    context.set_context(save_graphs=True)

    # Positional order expected by Net: shape, offset, reduce_scatter_flag, split_num.
    net_args = ([64, 8], 8, True, 8)
    grad_net = GradWrap(NetWithLoss(Net(*net_args)))
    grad_net.set_auto_parallel()

    # All-ones float32 inputs; only graph compilation is exercised here.
    first_input = Tensor(np.ones([64, 32]), dtype=ms.float32)
    second_input = Tensor(np.ones([8, 32, 8]), dtype=ms.float32)
    _executor.compile(grad_net, first_input, second_input)
|
||||||
|
|
Loading…
Reference in New Issue