forked from mindspore-Ecosystem/mindspore
fix reshape reshape case
This commit is contained in:
parent
2e684e89e7
commit
96c9569dca
|
@ -88,7 +88,9 @@ class TwoReshapeEliminater : public AnfVisitor {
|
||||||
|
|
||||||
auto fg = node->func_graph();
|
auto fg = node->func_graph();
|
||||||
if (fg != nullptr && x_ != nullptr && shape_ != nullptr) {
|
if (fg != nullptr && x_ != nullptr && shape_ != nullptr) {
|
||||||
return fg->NewCNode({NewValueNode(prim_), x_, shape_});
|
auto new_node = fg->NewCNode({NewValueNode(prim_), x_, shape_});
|
||||||
|
new_node->set_abstract(node->abstract());
|
||||||
|
return new_node;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,6 @@ class GradWrap(nn.Cell):
|
||||||
return C.grad_all(self.network)(x)
|
return C.grad_all(self.network)(x)
|
||||||
|
|
||||||
|
|
||||||
# core dump, step_auto_parallel should SetInputs for transpose axis
|
|
||||||
def test_reshape_matmul():
|
def test_reshape_matmul():
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -68,6 +67,28 @@ def test_reshape_matmul():
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
_executor.compile(net, x)
|
_executor.compile(net, x)
|
||||||
|
|
||||||
|
def test_reshape_reshape():
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.relu(x)
|
||||||
|
out = self.reshape(x, (64, 28))
|
||||||
|
out = self.reshape(out, (64, 28, 1))
|
||||||
|
return out
|
||||||
|
|
||||||
|
size = 8
|
||||||
|
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||||
|
x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
|
||||||
|
|
||||||
|
net = GradWrap(NetWithLoss(Net()))
|
||||||
|
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||||
|
net.set_auto_parallel()
|
||||||
|
_executor.compile(net, x)
|
||||||
|
|
||||||
|
|
||||||
def test_reshape_auto_1():
|
def test_reshape_auto_1():
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
|
|
Loading…
Reference in New Issue