!45709 Keep isolated side effect node created by new primitive class in the constant returned func graph.

Merge pull request !45709 from 张清华/opt_isolated_side_effect
This commit is contained in:
i-robot 2022-11-18 14:39:51 +00:00 committed by Gitee
commit 7f113b13fe
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 73 additions and 30 deletions

View File

@ -39,60 +39,60 @@ using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
// slice a tensor
// args: tensor, slice or slice tuple
size_t arg_length = args_spec_list.size();
size_t arg_length = args_abs_list.size();
const size_t min_args_size = 2;
if (arg_length < min_args_size) {
MS_LOG(EXCEPTION) << "The UnpackCall operator requires at least two arguments, but got " << arg_length << ".";
}
// No need to check, check will be done in infer.
auto ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret_graph->debug_info()->set_name("UnpackCall");
auto res_graph = std::make_shared<FuncGraph>();
res_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
res_graph->debug_info()->set_name("UnpackCall");
AnfNodePtr fn_node = ret_graph->add_parameter();
AnfNodePtr fn_node = res_graph->add_parameter();
std::vector<AnfNodePtr> elems;
elems.push_back(fn_node);
for (size_t index = 1; index < arg_length; index++) {
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
if (args_spec_list[index]->isa<AbstractTuple>()) {
auto arg_tuple = args_spec_list[index]->cast<AbstractTuplePtr>();
AnfNodePtr para_tuple = ret_graph->add_parameter();
MS_EXCEPTION_IF_NULL(args_abs_list[index]);
if (args_abs_list[index]->isa<AbstractTuple>()) {
auto arg_tuple = args_abs_list[index]->cast<AbstractTuplePtr>();
AnfNodePtr para_tuple = res_graph->add_parameter();
for (size_t i = 0; i < arg_tuple->size(); ++i) {
elems.push_back(
ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToLong(i))}));
res_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToLong(i))}));
}
} else if (args_spec_list[index]->isa<AbstractList>()) {
auto arg_list = args_spec_list[index]->cast<AbstractListPtr>();
AnfNodePtr para_list = ret_graph->add_parameter();
} else if (args_abs_list[index]->isa<AbstractList>()) {
auto arg_list = args_abs_list[index]->cast<AbstractListPtr>();
AnfNodePtr para_list = res_graph->add_parameter();
for (size_t i = 0; i < arg_list->size(); ++i) {
elems.push_back(
ret_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), para_list, NewValueNode(SizeToLong(i))}));
res_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), para_list, NewValueNode(SizeToLong(i))}));
}
} else if (args_spec_list[index]->isa<AbstractDictionary>()) {
AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>();
AnfNodePtr para_dict = ret_graph->add_parameter();
} else if (args_abs_list[index]->isa<AbstractDictionary>()) {
AbstractDictionaryPtr arg_dict = args_abs_list[index]->cast<AbstractDictionaryPtr>();
AnfNodePtr para_dict = res_graph->add_parameter();
auto dict_elems = arg_dict->elements();
(void)std::transform(
dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(elems),
[ret_graph, para_dict](const AbstractAttribute &item) {
[res_graph, para_dict](const AbstractAttribute &item) {
// Dict_elems's first element represents parameter names, which should be string type.
auto key_value = GetValue<std::string>(item.first->BuildValue());
auto dict_get_item =
ret_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(key_value)});
return ret_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(key_value), dict_get_item});
res_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(key_value)});
return res_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(key_value), dict_get_item});
});
} else {
MS_LOG(EXCEPTION) << "The arguments of UnpackCall operator should be tuple, list or dict, but got "
<< args_spec_list[index]->ToString();
<< args_abs_list[index]->ToString();
}
}
// Add to order list to trace if fn_node had side effect.
ret_graph->set_output(ret_graph->NewCNodeInOrder(elems));
return ret_graph;
res_graph->set_output(res_graph->NewCNodeInOrder(elems));
return res_graph;
}
} // namespace prim
} // namespace mindspore

View File

@ -2023,17 +2023,33 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
// Process the object.
MS_EXCEPTION_IF_NULL(out_conf->node());
TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(obj, &converted_ret, true);
ValuePtr converted_res = nullptr;
bool converted = parse::ConvertData(obj, &converted_res, true);
if (!converted) {
MS_LOG(EXCEPTION) << "Convert the python object failed";
}
MS_EXCEPTION_IF_NULL(converted_ret);
if (converted_ret->isa<FuncGraph>()) {
AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
MS_EXCEPTION_IF_NULL(converted_res);
// To check isolated side effect for the func graph who returns constant.
if (engine->check_isolated_side_effect()) {
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", converted_res: " << converted_res->ToString();
auto prim = GetValueWithoutDoSignature(converted_res)->cast<PrimitivePtr>();
if (prim != nullptr) {
auto effect_info = GetPrimEffectInfo(prim);
if (effect_info.memory || effect_info.io) {
MS_LOG(INFO) << "Found Side Effect Primitive CNode: " << out_conf->node()->DebugString();
const auto &cnode = dyn_cast<CNode>(out_conf->node());
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_has_isolated_side_effect_node(true);
out_conf->func_graph()->set_has_isolated_side_effect_node(true);
}
}
}
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
if (converted_res->isa<FuncGraph>()) {
AddToManager(engine, converted_res->cast<FuncGraphPtr>());
}
AbstractBasePtr ret = ToAbstract(converted_res, AnalysisContext::DummyContext(), out_conf);
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
return infer_result;

View File

@ -1863,3 +1863,30 @@ def test_print_assign_print():
'param_3:\nTensor(shape=[], dtype=Int32, value=3)\n\n'}
check_output(cap.output, patterns)
np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_in_constant_returned_func():
"""
Feature: Auto Monad
Description: Test print in a func graph who returns constant.
Expectation: No exception.
"""
class Print(Cell):
def construct(self):
x = tuple((1, 2, 3, 4, 5))
print("x:", x)
return x
cap = Capture()
with capture(cap):
net = Print()
net()
sys.stdout.flush()
time.sleep(0.1)
patterns = {'x:\n(1, 2, 3, 4, 5)'}
check_output(cap.output, patterns)