From 12e91071628a5f9589d03d84e586992c500639c1 Mon Sep 17 00:00:00 2001 From: huangxinjing Date: Tue, 10 Nov 2020 16:57:57 +0800 Subject: [PATCH] Fix VirtualDiv Int32 error --- mindspore/ops/_grad/grad_comm_ops.py | 3 ++- .../parallel/test_unsortedsegmentsum.py | 20 +++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 12de04427c8..3f5dc390b4a 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -225,7 +225,8 @@ def get_bprop_virtual_div_operator(self): def bprop(x, out, dout): if F.issubclass_(F.typeof(dout), mstype.tensor): - if F.issubclass_(F.dtype(dout), mstype.bool_): + if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \ + or F.issubclass_(F.dtype(dout), mstype.int16): return (dout,) dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout))) return (dx,) diff --git a/tests/ut/python/parallel/test_unsortedsegmentsum.py b/tests/ut/python/parallel/test_unsortedsegmentsum.py index 1b0d3b96826..6ea84a1467d 100644 --- a/tests/ut/python/parallel/test_unsortedsegmentsum.py +++ b/tests/ut/python/parallel/test_unsortedsegmentsum.py @@ -21,7 +21,6 @@ from mindspore import context from mindspore.common.api import _executor from mindspore.ops import composite as C from mindspore.ops import operations as P -from mindspore.ops.operations.comm_ops import _VirtualDataset from tests.ut.python.ops.test_math_ops import VirtualLoss context.set_context(mode=context.GRAPH_MODE) @@ -33,7 +32,6 @@ grad_all = C.GradOperation(get_all=True) class Net(nn.Cell): def __init__(self, strategy1, strategy2, num_segments): super(Net, self).__init__() - self.virtual_dataset = _VirtualDataset() self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2)) self.num_segments = num_segments @@ -54,8 +52,8 @@ class GradWrap(nn.Cell): class NetWithLoss(nn.Cell): def __init__(self, network): super(NetWithLoss, self).__init__() - self.loss = VirtualLoss() self.network = network + self.loss = VirtualLoss() def construct(self, x, y): predict = self.network(x, y) @@ -63,13 +61,13 @@ class NetWithLoss(nn.Cell): def compile_graph(x, y, segments, strategy1, strategy2, auto=False): - net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments))) - net.set_auto_parallel() - net.set_train() if auto: context.set_auto_parallel_context(parallel_mode="auto_parallel") else: context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments))) + net.set_auto_parallel() + net.set_train() _executor.compile(net, x, y) @@ -151,3 +149,13 @@ def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d(): strategy1 = (2, 1, 2) strategy2 = (2, 1) compile_graph(x, y, num_segments, strategy1, strategy2) + + +def test_unsortedsegmentsum_model_parallel_repeat_caculate(): + context.set_auto_parallel_context(device_num=4, global_rank=0) + x = Tensor(np.ones((4, 4, 8)), ms.float32) + y = Tensor(np.ones((4, 4)), ms.int32) + num_segments = 16 + strategy1 = (1, 1, 1) + strategy2 = (1, 1) + compile_graph(x, y, num_segments, strategy1, strategy2)