forked from mindspore-Ecosystem/mindspore
[Control-flow] Fix a bug in parameter Assign eliminating
This commit is contained in:
parent
61b445ff41
commit
d22e521a7f
|
@ -1185,8 +1185,19 @@ class ExecuteOrderGenerator {
|
||||||
MS_EXCEPTION_IF_NULL(target);
|
MS_EXCEPTION_IF_NULL(target);
|
||||||
auto para = param_write_times.find(target);
|
auto para = param_write_times.find(target);
|
||||||
if (para != param_write_times.end() && para->second == 1) {
|
if (para != param_write_times.end() && para->second == 1) {
|
||||||
// If target only write once, replace target with source and erase assign node.
|
// Check source of the Assign.
|
||||||
auto &source = node->inputs().at(kAssignSourceIndex);
|
auto &source = node->inputs().at(kAssignSourceIndex);
|
||||||
|
MS_EXCEPTION_IF_NULL(source);
|
||||||
|
if (source->isa<Parameter>()) {
|
||||||
|
auto it = param_write_times.find(source);
|
||||||
|
if (it != param_write_times.end() && it->second > 0) {
|
||||||
|
// Skip if Assign source is a parameter and be written in other place.
|
||||||
|
++iter;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If target only write once, and source not be written,
|
||||||
|
// replace target with source and erase the Assign node.
|
||||||
auto kg = target->func_graph()->cast<KernelGraphPtr>();
|
auto kg = target->func_graph()->cast<KernelGraphPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(kg);
|
MS_EXCEPTION_IF_NULL(kg);
|
||||||
kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source));
|
kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source));
|
||||||
|
|
|
@ -1429,6 +1429,33 @@ def test_if_cast():
|
||||||
np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy())
|
np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_while_forward():
|
||||||
|
class MyWhileNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.max = P.ReduceMax()
|
||||||
|
|
||||||
|
def construct(self, idx, end, x):
|
||||||
|
while idx < end:
|
||||||
|
part = x[idx, :, :]
|
||||||
|
max_num = self.max(part)
|
||||||
|
x[idx, :, 0:2] = max_num
|
||||||
|
idx = idx + 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
net = MyWhileNet()
|
||||||
|
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||||
|
end = Tensor(np.array(2), dtype=ms.int32)
|
||||||
|
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||||
|
output = net(idx, end, x)
|
||||||
|
expect = np.array([[[3, 3], [3, 3]], [[7, 7], [7, 7]]], dtype=np.int32)
|
||||||
|
assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="not supported yet")
|
@pytest.mark.skip(reason="not supported yet")
|
||||||
def test_multi_add_assign():
|
def test_multi_add_assign():
|
||||||
class Net(Cell):
|
class Net(Cell):
|
||||||
|
|
Loading…
Reference in New Issue