forked from mindspore-Ecosystem/mindspore
!49472 Fixed the bug of side effect in fallback.
Merge pull request !49472 from Margaret_wangrui/fallback_side_effect
This commit is contained in:
commit
33be16d103
|
@ -54,13 +54,14 @@ AnfNodePtr ConvertInterpretedObjectToPyExecute(const FuncGraphPtr &fg, const Val
|
|||
// Build new CNode for value node.
|
||||
ValuePtrList keys({std::make_shared<StringImm>(value_node_key)});
|
||||
ValuePtrList values({std::make_shared<StringImm>(value_node_key)});
|
||||
const auto interpreted_cnode = fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str),
|
||||
NewValueNode(std::make_shared<ValueTuple>(keys)),
|
||||
NewValueNode(std::make_shared<ValueTuple>(values))});
|
||||
const auto interpreted_cnode = fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str),
|
||||
NewValueNode(std::make_shared<ValueTuple>(keys)),
|
||||
NewValueNode(std::make_shared<ValueTuple>(values))});
|
||||
constexpr auto debug_recursive_level = 2;
|
||||
MS_LOG(DEBUG) << "original node: " << node->DebugString(debug_recursive_level)
|
||||
<< ", interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level);
|
||||
interpreted_cnode->set_debug_info(node->debug_info());
|
||||
fg->ReplaceInOrder(node, interpreted_cnode);
|
||||
return interpreted_cnode;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -373,7 +373,7 @@ EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const A
|
|||
auto fg = cnode->func_graph();
|
||||
const auto value_tuple_node = fg->NewCNode(value_list);
|
||||
|
||||
const auto getattr_obj_call_node = fg->NewCNodeInOrder(
|
||||
const auto getattr_obj_call_node = fg->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_call_str), NewValueNode(key_tuple), value_tuple_node});
|
||||
MS_LOG(DEBUG) << "getattr_obj_call_node: " << getattr_obj_call_node->DebugString();
|
||||
|
||||
|
@ -382,7 +382,7 @@ EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const A
|
|||
AnalysisEnginePtr eng = conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(getattr_obj_call_node, conf->context(), conf->func_graph());
|
||||
return eng->ForwardConfig(conf, fn_conf);
|
||||
return eng->ForwardConfig(conf, fn_conf, false);
|
||||
}
|
||||
|
||||
AbstractBasePtr AnalysisEngine::GetCNodeOperatorAbstract(const CNodePtr &cnode, const AnalysisContextPtr &context,
|
||||
|
@ -807,7 +807,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|||
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from " << func->type_name();
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
|
||||
EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf,
|
||||
bool need_erase) {
|
||||
MS_EXCEPTION_IF_NULL(orig_conf);
|
||||
MS_EXCEPTION_IF_NULL(new_conf);
|
||||
// If always_eval_flag is true in BaseFuncGraphEvaluaotr, then the CNode with same orig_conf may be forwarded
|
||||
|
@ -816,7 +817,7 @@ EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, c
|
|||
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->ToString() << ", to new_conf: " << new_conf->ToString();
|
||||
auto old_cnode = orig_conf->node()->cast_ptr<CNode>();
|
||||
auto new_cnode = new_conf->node()->cast<CNodePtr>();
|
||||
if (old_cnode != nullptr && new_cnode != nullptr) {
|
||||
if (need_erase && old_cnode != nullptr && new_cnode != nullptr) {
|
||||
if (old_cnode->func_graph() == new_cnode->func_graph()) {
|
||||
MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->DebugString()
|
||||
<< ", as origin node should be in order list, origin_node: " << old_cnode->DebugString();
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -322,7 +322,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
|
||||
// Set the analysis result for orig to the result for new.
|
||||
// This sets an entry in anfnode_config_map from orig to new.
|
||||
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
|
||||
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf,
|
||||
bool need_erase = true);
|
||||
const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
|
||||
|
||||
FuncGraphPtr root_func_graph() const { return root_func_graph_; }
|
||||
|
|
|
@ -902,6 +902,7 @@ class UNet(ms.nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support PyExecute Add.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -916,7 +917,8 @@ def test_resolve_cust_class():
|
|||
net = UNet(UserDefinedNet())
|
||||
x = np.array([10], np.float32)
|
||||
output = net(ms.Tensor(x))
|
||||
print(output) # The output should == 200, but failed, check later.
|
||||
print(output)
|
||||
assert output == 200
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -935,3 +937,34 @@ def test_resolve_cust_ms_function_call_class():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
net(ms.Tensor(x))
|
||||
assert "Nested execution during JIT execution is not supported." in str(err.value)
|
||||
|
||||
|
||||
class PrintPyExecuteNet(ms.nn.Cell):
|
||||
def __init__(self, net):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, x):
|
||||
out = x * x
|
||||
print("out1:", out)
|
||||
out = self.net(x) + out
|
||||
print("out2:", out)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_pyexecute():
|
||||
"""
|
||||
Feature: Side effect in Fallback runtime.
|
||||
Description: Side effect in Fallback runtime.
|
||||
Expectation: No error.
|
||||
"""
|
||||
net = PrintPyExecuteNet(UserDefinedNet())
|
||||
x = np.array([10], np.float64)
|
||||
output = net(ms.Tensor(x))
|
||||
print(output)
|
||||
assert output == 200
|
||||
|
|
Loading…
Reference in New Issue