!49472 Fixed the bug of side effect in fallback.

Merge pull request !49472 from Margaret_wangrui/fallback_side_effect
This commit is contained in:
i-robot 2023-02-28 07:33:52 +00:00 committed by Gitee
commit 33be16d103
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 46 additions and 10 deletions

View File

@ -54,13 +54,14 @@ AnfNodePtr ConvertInterpretedObjectToPyExecute(const FuncGraphPtr &fg, const Val
// Build new CNode for value node.
ValuePtrList keys({std::make_shared<StringImm>(value_node_key)});
ValuePtrList values({std::make_shared<StringImm>(value_node_key)});
const auto interpreted_cnode = fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str),
NewValueNode(std::make_shared<ValueTuple>(keys)),
NewValueNode(std::make_shared<ValueTuple>(values))});
const auto interpreted_cnode = fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str),
NewValueNode(std::make_shared<ValueTuple>(keys)),
NewValueNode(std::make_shared<ValueTuple>(values))});
constexpr auto debug_recursive_level = 2;
MS_LOG(DEBUG) << "original node: " << node->DebugString(debug_recursive_level)
<< ", interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level);
interpreted_cnode->set_debug_info(node->debug_info());
fg->ReplaceInOrder(node, interpreted_cnode);
return interpreted_cnode;
}
} // namespace mindspore

View File

@ -373,7 +373,7 @@ EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const A
auto fg = cnode->func_graph();
const auto value_tuple_node = fg->NewCNode(value_list);
const auto getattr_obj_call_node = fg->NewCNodeInOrder(
const auto getattr_obj_call_node = fg->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_call_str), NewValueNode(key_tuple), value_tuple_node});
MS_LOG(DEBUG) << "getattr_obj_call_node: " << getattr_obj_call_node->DebugString();
@ -382,7 +382,7 @@ EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const A
AnalysisEnginePtr eng = conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(getattr_obj_call_node, conf->context(), conf->func_graph());
return eng->ForwardConfig(conf, fn_conf);
return eng->ForwardConfig(conf, fn_conf, false);
}
AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
@ -807,7 +807,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from " << func->type_name();
}
EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf,
bool need_erase) {
MS_EXCEPTION_IF_NULL(orig_conf);
MS_EXCEPTION_IF_NULL(new_conf);
// If always_eval_flag is true in BaseFuncGraphEvaluaotr, then the CNode with same orig_conf may be forwarded
@ -816,7 +817,7 @@ EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, c
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->ToString() << ", to new_conf: " << new_conf->ToString();
auto old_cnode = orig_conf->node()->cast_ptr<CNode>();
auto new_cnode = new_conf->node()->cast<CNodePtr>();
if (old_cnode != nullptr && new_cnode != nullptr) {
if (need_erase && old_cnode != nullptr && new_cnode != nullptr) {
if (old_cnode->func_graph() == new_cnode->func_graph()) {
MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->DebugString()
<< ", as origin node should be in order list, origin_node: " << old_cnode->DebugString();

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -322,7 +322,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// Set the analysis result for orig to the result for new.
// This sets an entry in anfnode_config_map from orig to new.
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf,
bool need_erase = true);
const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
FuncGraphPtr root_func_graph() const { return root_func_graph_; }

View File

@ -902,6 +902,7 @@ class UNet(ms.nn.Cell):
return out
@pytest.mark.skip(reason="No support PyExecute Add.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -916,7 +917,8 @@ def test_resolve_cust_class():
net = UNet(UserDefinedNet())
x = np.array([10], np.float32)
output = net(ms.Tensor(x))
print(output) # The output should == 200, but failed, check later.
print(output)
assert output == 200
@pytest.mark.level0
@ -935,3 +937,34 @@ def test_resolve_cust_ms_function_call_class():
with pytest.raises(RuntimeError) as err:
net(ms.Tensor(x))
assert "Nested execution during JIT execution is not supported." in str(err.value)
class PrintPyExecuteNet(ms.nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
def construct(self, x):
out = x * x
print("out1:", out)
out = self.net(x) + out
print("out2:", out)
return out
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_print_pyexecute():
"""
Feature: Side effect in Fallback runtime.
Description: Side effect in Fallback runtime.
Expectation: No error.
"""
net = PrintPyExecuteNet(UserDefinedNet())
x = np.array([10], np.float64)
output = net(ms.Tensor(x))
print(output)
assert output == 200