!48873 Handle the problem that a subgraph with side effects is incorrectly eliminated.

Merge pull request !48873 from Margaret_wangrui/side_effect_none
i-robot 2023-02-16 08:11:06 +00:00 committed by Gitee
commit 45b230d727
2 changed files with 42 additions and 1 deletion
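
For context (an illustrative sketch, not part of this commit), the pattern the change targets looks roughly like the following, assuming the usual MindSpore graph-mode API: the call to update_param() returns None and its result is unused, so it becomes an isolated node, but the parameter update inside it is a side effect and its subgraph must not be eliminated. The new test below exercises the same pattern.

import mindspore as ms
from mindspore import nn, Tensor, Parameter

ms.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.param = Parameter(Tensor([5], dtype=ms.int32), name='p')

    def update_param(self, weight):
        # Side effect only: the returned None is never used by the caller.
        self.param = 2 * weight
        return None

    def construct(self, weight):
        self.update_param(weight)  # isolated call whose subgraph has a side effect
        return self.param

net = Net()
print(net(Tensor([2], dtype=ms.int32)))  # expected: [4] once the side-effect subgraph is kept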


@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2022 Huawei Technologies Co., Ltd
+ * Copyright 2019-2023 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -507,6 +507,16 @@ bool HasIsolatedSideEffectNode(const FuncGraphPtr &func_graph) {
         }
       }
     } else {
+      // Process call function
+      if (isolated_node->isa<CNode>()) {
+        auto first_input = isolated_node->cast<CNodePtr>()->input(0);
+        if (IsValueNode<FuncGraph>(first_input)) {
+          auto func = GetValueNode<FuncGraphPtr>(first_input);
+          if (IsSideEffectCNode(func->output()) || HasIsolatedSideEffectNode(func)) {
+            return true;
+          }
+        }
+      }
       if (IsSideEffectCNode(isolated_node)) {
         MS_LOG(DEBUG) << "Single isolated side-effect node: " << isolated_node->DebugString();
         return true;


@@ -21,6 +21,7 @@ import numpy as np
 import mindspore as ms
 import mindspore.ops.operations as P
 import mindspore.nn as nn
+import mindspore.common.dtype as mstype
 from mindspore.nn import Cell
 from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate
 from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
@@ -1890,3 +1891,33 @@ def test_print_in_constant_returned_func():
     patterns = {'x:\n(1, 2, 3, 4, 5)'}
     check_output(cap.output, patterns)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_return_none_with_side_effect():
+    """
+    Feature: Support None.
+    Description: Support None as the output of a function with side effects.
+    Expectation: No exception.
+    """
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param = Parameter(Tensor([5], dtype=mstype.int32), name='name_a')
+
+        def update_param(self, weight):  # pylint: disable=R1711
+            self.param = 2 * weight
+            return None
+
+        def construct(self, weight):
+            self.update_param(weight)
+            return self.param
+
+    net = Net()
+    input_x = Tensor([2], dtype=mstype.int32)
+    res = net(input_x)
+    assert res == 4