!49869 Support to return isolated node in control flow branches with raise node.

Merge pull request !49869 from Margaret_wangrui/raise_none
This commit is contained in:
i-robot 2023-03-08 03:42:19 +00:00 committed by Gitee
commit 7e8d55b533
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 136 additions and 62 deletions

View File

@ -612,6 +612,25 @@ class CleanAfterOptARewriter : public BaseRewriter {
: BaseRewriter(root_graph, manager) {}
~CleanAfterOptARewriter() override = default;
void UpdateAbstracts() {
const auto &nodes = manager_->all_nodes();
for (const auto &node : nodes) {
const auto &abs = node->abstract();
if (abs == nullptr) {
continue;
}
// Set flag for convert AbstractNone(PyExecute) to AbstractTensor in next renormalize.
if (IsPrimitiveCNode(node, prim::kPrimPyExecute) && abs->isa<abstract::AbstractNone>()) {
constexpr auto data_type = "__py_execute_no_return_type__";
if (node->has_user_data(data_type)) {
auto type = std::make_shared<TypeAnything>();
node->set_user_data<Type>(data_type, type);
set_need_renormalized(true);
}
}
}
}
protected:
// From:
// MakeSparseTensor(indices, values, dense_shape)
@ -897,6 +916,13 @@ class CleanAfterOptARewriter : public BaseRewriter {
AnfNodePtr none_execute_node =
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimPyExecute), script_node, none_tuple_node, none_tuple_node});
MS_LOG(DEBUG) << "none_execute_node:" << none_execute_node->DebugString();
// Keep AbstractNone for PyExecute, because the control flow join problem.
auto none_type = std::make_shared<TypeNone>();
none_execute_node->set_user_data<Type>("__py_execute_no_return_type__", none_type);
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
res->set_value(kAnyValue);
none_execute_node->set_abstract(res);
return none_execute_node;
}
@ -1051,6 +1077,7 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resou
CleanAfterOptARewriter rewriter(root, manager);
bool change = rewriter.Execute();
// Renormalize for new PyExecute node.
rewriter.UpdateAbstracts();
if (rewriter.need_renormalized()) {
abstract::AbstractBasePtrList new_args_spec;
std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),

View File

@ -225,6 +225,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
if (enable_eliminate_unused_element) {
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &maybe_func = engine->GetCNodeOperatorAbstract(cnode, context, fg);
if (maybe_func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
maybe_func->isa<abstract::FuncGraphAbstractClosure>()) {
@ -232,6 +233,12 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(engine, fg, cnode, abs_func_graph, context);
}
}
if (engine->check_isolated_side_effect() && node_eval_result->has_isolated_side_effect()) {
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_has_isolated_side_effect_node(true);
fg->set_has_isolated_side_effect_node(true);
}
MS_LOG(DEBUG) << "No need to jump as found result from cache for node_config";
} else {
node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf);

View File

@ -2064,12 +2064,16 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst
MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
// when return value should be none
if (current_interpret_node->has_user_data("__py_execute_no_return_type__")) {
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
res->set_value(kAnyValue);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
constexpr auto data_type = "__py_execute_no_return_type__";
if (current_interpret_node->has_user_data(data_type)) {
auto type = current_interpret_node->user_data<Type>(data_type);
if (type->isa<TypeNone>()) {
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
res->set_value(kAnyValue);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}
}
TypePtr type = kFloat64;
if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) {
@ -2811,6 +2815,14 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
auto none_type = std::make_shared<TypeNone>();
raise_error_node->set_user_data<Type>("__py_execute_no_return_type__", none_type);
cur_graph->ReplaceInOrder(node, raise_error_node);
// Set isolated side effect flag for raise node.
const auto &manager = cur_graph->manager();
manager->Replace(node, raise_error_node);
raise_error_node->set_has_isolated_side_effect_node(true);
cur_graph->set_has_isolated_side_effect_node(true);
MS_LOG(DEBUG) << "Found Side Effect Primitive CNode: " << raise_error_node->DebugString();
AnalysisEnginePtr eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(raise_error_node, out_conf->context(), out_conf->func_graph());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
* Copyright 2021-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -199,6 +199,13 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf);
} else {
node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
MS_EXCEPTION_IF_NULL(node_eval_result);
if (engine->check_isolated_side_effect() && node_eval_result->has_isolated_side_effect()) {
const auto &cnode = current_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_has_isolated_side_effect_node(true);
current_context_->func_graph()->set_has_isolated_side_effect_node(true);
}
}
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString() << ") = "
<< (node_eval_result->abstract() ? node_eval_result->abstract()->ToString() : "Abstract null");

View File

@ -695,6 +695,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
if (contains_isolated_side_effect) {
cnode->set_has_isolated_side_effect_node(true);
conf->func_graph()->set_has_isolated_side_effect_node(true);
eval_result->set_has_isolated_side_effect(true);
}
}
return eval_result;

View File

@ -59,16 +59,23 @@ using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
// the class to save evaluated result: abstract value and modified attribute
class EvalResult : public Base {
public:
EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr) : abstract_(abs), attribute_(attr) {}
EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr)
: abstract_(abs), attribute_(attr), has_isolated_side_effect_(false) {}
~EvalResult() override = default;
MS_DECLARE_PARENT(EvalResult, Base);
const AbstractBasePtr &abstract() const { return abstract_; }
const AttrValueMapPtr &attribute() const { return attribute_; }
bool has_isolated_side_effect() const { return has_isolated_side_effect_; }
void set_has_isolated_side_effect(bool has_isolated_side_effect) {
has_isolated_side_effect_ = has_isolated_side_effect;
}
private:
AbstractBasePtr abstract_;
// Attribute related to PrimEvaluator;
AttrValueMapPtr attribute_;
bool has_isolated_side_effect_;
};
using EvalResultPtr = std::shared_ptr<EvalResult>;

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -101,24 +101,6 @@ bool CheckAbstractScalar(const AnfNodePtr &node) {
return false;
}
bool CheckIfRaise(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
auto first = inputs[1];
auto script_node = first->cast<ValueNodePtr>();
if (script_node->value()->isa<StringImm>()) {
auto script = GetValueNode<StringImmPtr>(script_node)->value();
std::string raise_script = "raise_func";
auto idx = script.find(raise_script);
if (idx != string::npos) {
return true;
}
}
}
return false;
}
void ValidateAbstract(const AnfNodePtr &node) {
if (node == nullptr) {
MS_LOG(DEBUG) << "Node to validate is invalid";
@ -141,11 +123,6 @@ void ValidateAbstract(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
return;
}
if (CheckIfRaise(node)) {
ShapeVector shp{abstract::Shape::kShapeRankAny};
auto abs = std::make_shared<abstract::AbstractTensor>(kFloat64, std::make_shared<abstract::Shape>(shp));
node->set_abstract(abs);
}
bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||

View File

@ -167,7 +167,6 @@ def test_none_is_default_value_of_parameter():
assert res1 is None
@pytest.mark.skip(reason="No support print x side effect.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -191,7 +190,7 @@ def test_none_is_default_value_of_parameter_2():
x = [1, 2]
y = [3, 4]
res2 = foo(x, y)
assert res2 == (3, 4)
assert res2 == [3, 4]
@pytest.mark.level0
@ -488,3 +487,66 @@ def test_none_is_return_of_sub_graph_control_flow():
data = Tensor(np.ones([2, 3]), dtype=ms.float32)
out = net(data)
assert (out.asnumpy() == data).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_none_is_return_of_sub_graph_control_flow_raise():
"""
Feature: Support None.
Description: Support None is the return of sub_graph in control flow. And Raise node is in sub_graph.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def inner_func(self, x): # pylint: disable=R1711
if x == 2:
raise ValueError("The input should not be ", x)
return None
def construct(self, x):
self.inner_func(x)
return x
net = RaiseNet()
res = net(Tensor(1))
print("res:", res)
assert res.asnumpy() == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_none_is_return_raise():
"""
Feature: Support None.
Description: Support None is the return of sub_graph in control flow. And Raise node is in sub_graph.
Expectation: No exception.
"""
def check_test(shp): # pylint: disable=R1711
if shp[0] > 5:
raise ValueError('raise value error.')
return None
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.one = Tensor(1, dtype=ms.float32)
def construct(self, x):
shp = x.shape
check_test(shp)
return x
with pytest.raises(ValueError, match="raise value error."):
np_data = np.random.randint(6, size=(6,))
data = Tensor(np_data, dtype=ms.float32)
dyn_tensor = Tensor(shape=[None], dtype=ms.float32)
net = Net()
net.set_inputs(dyn_tensor)
out = net(data)
assert out == data

View File

@ -228,10 +228,10 @@ def test_raise_with_variable_control_flow2():
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x, y):
def construct(self, x, y): # pylint: disable=R1711
if x == y:
raise RuntimeError(f"The input should not be {x}.")
return x
return None
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
net = RaiseNet()
@ -243,32 +243,6 @@ def test_raise_with_variable_control_flow2():
raise_info_joinedstr_tensor.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_raise_with_variable_control_flow3():
"""
Feature: graph raise by JIT Fallback.
Description: Test raise.
Expectation: No exception.
"""
class RaiseNet(nn.Cell):
def construct(self, x, y, z):
if x == y:
raise RuntimeError(f"The input should not be {x}.")
return z
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
net = RaiseNet()
x = Tensor(1)
y = Tensor(1)
z = (x, y)
res = net(x, y, z)
print("res:", res)
assert "The input should not be 1" in str(
raise_info_joinedstr_tensor.value)
@pytest.mark.skip(reason="not support yet")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu