forked from mindspore-Ecosystem/mindspore
!48873 Handle the problem that the subgraph is eliminated, when the subgraph has side effects.
Merge pull request !48873 from Margaret_wangrui/side_effect_none
This commit is contained in:
commit
45b230d727
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -507,6 +507,16 @@ bool HasIsolatedSideEffectNode(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
// Process call function
|
||||
if (isolated_node->isa<CNode>()) {
|
||||
auto first_input = isolated_node->cast<CNodePtr>()->input(0);
|
||||
if (IsValueNode<FuncGraph>(first_input)) {
|
||||
auto func = GetValueNode<FuncGraphPtr>(first_input);
|
||||
if (IsSideEffectCNode(func->output()) || HasIsolatedSideEffectNode(func)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (IsSideEffectCNode(isolated_node)) {
|
||||
MS_LOG(DEBUG) << "Single isolated side-effect node: " << isolated_node->DebugString();
|
||||
return true;
|
||||
|
|
|
@ -21,6 +21,7 @@ import numpy as np
|
|||
import mindspore as ms
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate
|
||||
from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
|
||||
|
@ -1890,3 +1891,33 @@ def test_print_in_constant_returned_func():
|
|||
|
||||
patterns = {'x:\n(1, 2, 3, 4, 5)'}
|
||||
check_output(cap.output, patterns)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_none_with_side_effect():
|
||||
"""
|
||||
Feature: Support None.
|
||||
Description: Support None is the output of_function with side effect.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param = Parameter(Tensor([5], dtype=mstype.int32), name='name_a')
|
||||
|
||||
def update_param(self, weight): # pylint: disable=R1711
|
||||
self.param = 2 * weight
|
||||
return None
|
||||
|
||||
def construct(self, weight):
|
||||
self.update_param(weight)
|
||||
return self.param
|
||||
|
||||
net = Net()
|
||||
input_x = Tensor([2], dtype=mstype.int32)
|
||||
res = net(input_x)
|
||||
assert res == 4
|
||||
|
|
Loading…
Reference in New Issue