forked from mindspore-Ecosystem/mindspore
[Control-flow] Fix a bug in parameter Assign eliminating
This commit is contained in:
parent
61b445ff41
commit
d22e521a7f
|
@ -1185,8 +1185,19 @@ class ExecuteOrderGenerator {
|
|||
MS_EXCEPTION_IF_NULL(target);
|
||||
auto para = param_write_times.find(target);
|
||||
if (para != param_write_times.end() && para->second == 1) {
|
||||
// If target only write once, replace target with source and erase assign node.
|
||||
// Check source of the Assign.
|
||||
auto &source = node->inputs().at(kAssignSourceIndex);
|
||||
MS_EXCEPTION_IF_NULL(source);
|
||||
if (source->isa<Parameter>()) {
|
||||
auto it = param_write_times.find(source);
|
||||
if (it != param_write_times.end() && it->second > 0) {
|
||||
// Skip if Assign source is a parameter and be written in other place.
|
||||
++iter;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// If target only write once, and source not be written,
|
||||
// replace target with source and erase the Assign node.
|
||||
auto kg = target->func_graph()->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source));
|
||||
|
|
|
@ -1429,6 +1429,33 @@ def test_if_cast():
|
|||
np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_forward():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
while idx < end:
|
||||
part = x[idx, :, :]
|
||||
max_num = self.max(part)
|
||||
x[idx, :, 0:2] = max_num
|
||||
idx = idx + 1
|
||||
return x
|
||||
|
||||
net = MyWhileNet()
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
output = net(idx, end, x)
|
||||
expect = np.array([[[3, 3], [3, 3]], [[7, 7], [7, 7]]], dtype=np.int32)
|
||||
assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not supported yet")
|
||||
def test_multi_add_assign():
|
||||
class Net(Cell):
|
||||
|
|
Loading…
Reference in New Issue