diff --git a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
index 5e7ce7cdd73..f1f73de4d93 100644
--- a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
+++ b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
@@ -88,7 +88,9 @@ class TwoReshapeEliminater : public AnfVisitor {
 
     auto fg = node->func_graph();
     if (fg != nullptr && x_ != nullptr && shape_ != nullptr) {
-      return fg->NewCNode({NewValueNode(prim_), x_, shape_});
+      auto new_node = fg->NewCNode({NewValueNode(prim_), x_, shape_});
+      new_node->set_abstract(node->abstract());
+      return new_node;
     }
     return nullptr;
   }
diff --git a/tests/ut/python/parallel/test_auto_parallel_reshape.py b/tests/ut/python/parallel/test_auto_parallel_reshape.py
index 0f987ddcb03..5b71d47a44c 100644
--- a/tests/ut/python/parallel/test_auto_parallel_reshape.py
+++ b/tests/ut/python/parallel/test_auto_parallel_reshape.py
@@ -45,7 +45,6 @@ class GradWrap(nn.Cell):
         return C.grad_all(self.network)(x)
 
 
-# core dump, step_auto_parallel should SetInputs for transpose axis
 def test_reshape_matmul():
     class Net(nn.Cell):
         def __init__(self):
@@ -68,6 +67,28 @@ def test_reshape_matmul():
     net.set_auto_parallel()
     _executor.compile(net, x)
 
+def test_reshape_reshape():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.relu = P.ReLU()
+
+        def construct(self, x):
+            x = self.relu(x)
+            out = self.reshape(x, (64, 28))
+            out = self.reshape(out, (64, 28, 1))
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)
+
 
 def test_reshape_auto_1():
     class Net(nn.Cell):