From ee7ae691cd054d07000fc07e6c0397aab57102d4 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Tue, 14 Feb 2023 14:45:11 +0800 Subject: [PATCH] Handle the problem that the subgraph is mistakenly eliminated, when the subgraph has side effects. --- mindspore/ccsrc/pipeline/jit/action.cc | 12 +++++++++- tests/st/auto_monad/test_auto_monad.py | 31 ++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 9fc4eace012..a0325cdcc7f 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2022 Huawei Technologies Co., Ltd + * Copyright 2019-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -507,6 +507,16 @@ bool HasIsolatedSideEffectNode(const FuncGraphPtr &func_graph) { } } } else { + // Process call function + if (isolated_node->isa()) { + auto first_input = isolated_node->cast()->input(0); + if (IsValueNode(first_input)) { + auto func = GetValueNode(first_input); + if (IsSideEffectCNode(func->output()) || HasIsolatedSideEffectNode(func)) { + return true; + } + } + } if (IsSideEffectCNode(isolated_node)) { MS_LOG(DEBUG) << "Single isolated side-effect node: " << isolated_node->DebugString(); return true; diff --git a/tests/st/auto_monad/test_auto_monad.py b/tests/st/auto_monad/test_auto_monad.py index 92c701f55f0..3d12848da2e 100644 --- a/tests/st/auto_monad/test_auto_monad.py +++ b/tests/st/auto_monad/test_auto_monad.py @@ -21,6 +21,7 @@ import numpy as np import mindspore as ms import mindspore.ops.operations as P import mindspore.nn as nn +import mindspore.common.dtype as mstype from mindspore.nn import Cell from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits @@ -1890,3 +1891,33 @@ def test_print_in_constant_returned_func(): patterns = {'x:\n(1, 2, 3, 4, 5)'} check_output(cap.output, patterns) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_none_with_side_effect(): + """ + Feature: Support None. + Description: Support None is the output of_function with side effect. + Expectation: No exception. + """ + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.param = Parameter(Tensor([5], dtype=mstype.int32), name='name_a') + + def update_param(self, weight): # pylint: disable=R1711 + self.param = 2 * weight + return None + + def construct(self, weight): + self.update_param(weight) + return self.param + + net = Net() + input_x = Tensor([2], dtype=mstype.int32) + res = net(input_x) + assert res == 4