forked from mindspore-Ecosystem/mindspore
Fix VirtualDiv Int32 error
This commit is contained in:
parent
d70f25edc0
commit
12e9107162
|
@ -225,7 +225,8 @@ def get_bprop_virtual_div_operator(self):
|
|||
|
||||
def bprop(x, out, dout):
|
||||
if F.issubclass_(F.typeof(dout), mstype.tensor):
|
||||
if F.issubclass_(F.dtype(dout), mstype.bool_):
|
||||
if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \
|
||||
or F.issubclass_(F.dtype(dout), mstype.int16):
|
||||
return (dout,)
|
||||
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
|
||||
return (dx,)
|
||||
|
|
|
@ -21,7 +21,6 @@ from mindspore import context
|
|||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -33,7 +32,6 @@ grad_all = C.GradOperation(get_all=True)
|
|||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2, num_segments):
|
||||
super(Net, self).__init__()
|
||||
self.virtual_dataset = _VirtualDataset()
|
||||
self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2))
|
||||
self.num_segments = num_segments
|
||||
|
||||
|
@ -54,8 +52,8 @@ class GradWrap(nn.Cell):
|
|||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
self.loss = VirtualLoss()
|
||||
|
||||
def construct(self, x, y):
|
||||
predict = self.network(x, y)
|
||||
|
@ -63,13 +61,13 @@ class NetWithLoss(nn.Cell):
|
|||
|
||||
|
||||
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
if auto:
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
else:
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
|
||||
|
@ -151,3 +149,13 @@ def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d():
|
|||
strategy1 = (2, 1, 2)
|
||||
strategy2 = (2, 1)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentsum_model_parallel_repeat_caculate():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=0)
|
||||
x = Tensor(np.ones((4, 4, 8)), ms.float32)
|
||||
y = Tensor(np.ones((4, 4)), ms.int32)
|
||||
num_segments = 16
|
||||
strategy1 = (1, 1, 1)
|
||||
strategy2 = (1, 1)
|
||||
compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
|
Loading…
Reference in New Issue