Fix VirtualDiv Int32 error

This commit is contained in:
huangxinjing 2020-11-10 16:57:57 +08:00
parent d70f25edc0
commit 12e9107162
2 changed files with 16 additions and 7 deletions

View File

@ -225,7 +225,8 @@ def get_bprop_virtual_div_operator(self):
def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):
if F.issubclass_(F.dtype(dout), mstype.bool_):
if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \
or F.issubclass_(F.dtype(dout), mstype.int16):
return (dout,)
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
return (dx,)

View File

@ -21,7 +21,6 @@ from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import _VirtualDataset
from tests.ut.python.ops.test_math_ops import VirtualLoss
context.set_context(mode=context.GRAPH_MODE)
@ -33,7 +32,6 @@ grad_all = C.GradOperation(get_all=True)
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, num_segments):
super(Net, self).__init__()
self.virtual_dataset = _VirtualDataset()
self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2))
self.num_segments = num_segments
@ -54,8 +52,8 @@ class GradWrap(nn.Cell):
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
self.loss = VirtualLoss()
def construct(self, x, y):
predict = self.network(x, y)
@ -63,13 +61,13 @@ class NetWithLoss(nn.Cell):
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
net.set_auto_parallel()
net.set_train()
if auto:
context.set_auto_parallel_context(parallel_mode="auto_parallel")
else:
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y)
@ -151,3 +149,13 @@ def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d():
strategy1 = (2, 1, 2)
strategy2 = (2, 1)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentsum_model_parallel_repeat_caculate():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float32)
y = Tensor(np.ones((4, 4)), ms.int32)
num_segments = 16
strategy1 = (1, 1, 1)
strategy2 = (1, 1)
compile_graph(x, y, num_segments, strategy1, strategy2)