!13786 [Control-flow] Fix a bug in parameter Assign eliminating

From: @hwhewei
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-23 09:04:47 +08:00 committed by Gitee
commit 1d56895ece
2 changed files with 39 additions and 1 deletions

View File

@ -1185,8 +1185,19 @@ class ExecuteOrderGenerator {
MS_EXCEPTION_IF_NULL(target);
auto para = param_write_times.find(target);
if (para != param_write_times.end() && para->second == 1) {
// If target only write once, replace target with source and erase assign node.
// Check source of the Assign.
auto &source = node->inputs().at(kAssignSourceIndex);
MS_EXCEPTION_IF_NULL(source);
if (source->isa<Parameter>()) {
auto it = param_write_times.find(source);
if (it != param_write_times.end() && it->second > 0) {
// Skip if Assign source is a parameter and be written in other place.
++iter;
continue;
}
}
// If target only write once, and source not be written,
// replace target with source and erase the Assign node.
auto kg = target->func_graph()->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kg);
kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source));

View File

@ -1429,6 +1429,33 @@ def test_if_cast():
np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_forward():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
def construct(self, idx, end, x):
while idx < end:
part = x[idx, :, :]
max_num = self.max(part)
x[idx, :, 0:2] = max_num
idx = idx + 1
return x
net = MyWhileNet()
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
output = net(idx, end, x)
expect = np.array([[[3, 3], [3, 3]], [[7, 7], [7, 7]]], dtype=np.int32)
assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001)
@pytest.mark.skip(reason="not supported yet")
def test_multi_add_assign():
class Net(Cell):