!45709 Keep isolated side effect node created by new primitive class in the constant returned func graph.
Merge pull request !45709 from 张清华/opt_isolated_side_effect
This commit is contained in:
commit
7f113b13fe
|
@ -39,60 +39,60 @@ using mindspore::abstract::AbstractListPtr;
|
|||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractTuplePtr;
|
||||
|
||||
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
|
||||
// slice a tensor
|
||||
// args: tensor, slice or slice tuple
|
||||
size_t arg_length = args_spec_list.size();
|
||||
size_t arg_length = args_abs_list.size();
|
||||
const size_t min_args_size = 2;
|
||||
if (arg_length < min_args_size) {
|
||||
MS_LOG(EXCEPTION) << "The UnpackCall operator requires at least two arguments, but got " << arg_length << ".";
|
||||
}
|
||||
|
||||
// No need to check, check will be done in infer.
|
||||
auto ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret_graph->debug_info()->set_name("UnpackCall");
|
||||
auto res_graph = std::make_shared<FuncGraph>();
|
||||
res_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
res_graph->debug_info()->set_name("UnpackCall");
|
||||
|
||||
AnfNodePtr fn_node = ret_graph->add_parameter();
|
||||
AnfNodePtr fn_node = res_graph->add_parameter();
|
||||
std::vector<AnfNodePtr> elems;
|
||||
elems.push_back(fn_node);
|
||||
for (size_t index = 1; index < arg_length; index++) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
||||
if (args_spec_list[index]->isa<AbstractTuple>()) {
|
||||
auto arg_tuple = args_spec_list[index]->cast<AbstractTuplePtr>();
|
||||
AnfNodePtr para_tuple = ret_graph->add_parameter();
|
||||
MS_EXCEPTION_IF_NULL(args_abs_list[index]);
|
||||
if (args_abs_list[index]->isa<AbstractTuple>()) {
|
||||
auto arg_tuple = args_abs_list[index]->cast<AbstractTuplePtr>();
|
||||
AnfNodePtr para_tuple = res_graph->add_parameter();
|
||||
for (size_t i = 0; i < arg_tuple->size(); ++i) {
|
||||
elems.push_back(
|
||||
ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToLong(i))}));
|
||||
res_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToLong(i))}));
|
||||
}
|
||||
} else if (args_spec_list[index]->isa<AbstractList>()) {
|
||||
auto arg_list = args_spec_list[index]->cast<AbstractListPtr>();
|
||||
AnfNodePtr para_list = ret_graph->add_parameter();
|
||||
} else if (args_abs_list[index]->isa<AbstractList>()) {
|
||||
auto arg_list = args_abs_list[index]->cast<AbstractListPtr>();
|
||||
AnfNodePtr para_list = res_graph->add_parameter();
|
||||
for (size_t i = 0; i < arg_list->size(); ++i) {
|
||||
elems.push_back(
|
||||
ret_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), para_list, NewValueNode(SizeToLong(i))}));
|
||||
res_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), para_list, NewValueNode(SizeToLong(i))}));
|
||||
}
|
||||
} else if (args_spec_list[index]->isa<AbstractDictionary>()) {
|
||||
AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>();
|
||||
AnfNodePtr para_dict = ret_graph->add_parameter();
|
||||
} else if (args_abs_list[index]->isa<AbstractDictionary>()) {
|
||||
AbstractDictionaryPtr arg_dict = args_abs_list[index]->cast<AbstractDictionaryPtr>();
|
||||
AnfNodePtr para_dict = res_graph->add_parameter();
|
||||
auto dict_elems = arg_dict->elements();
|
||||
(void)std::transform(
|
||||
dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(elems),
|
||||
[ret_graph, para_dict](const AbstractAttribute &item) {
|
||||
[res_graph, para_dict](const AbstractAttribute &item) {
|
||||
// Dict_elems's first element represents parameter names, which should be string type.
|
||||
auto key_value = GetValue<std::string>(item.first->BuildValue());
|
||||
auto dict_get_item =
|
||||
ret_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(key_value)});
|
||||
return ret_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(key_value), dict_get_item});
|
||||
res_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(key_value)});
|
||||
return res_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(key_value), dict_get_item});
|
||||
});
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The arguments of UnpackCall operator should be tuple, list or dict, but got "
|
||||
<< args_spec_list[index]->ToString();
|
||||
<< args_abs_list[index]->ToString();
|
||||
}
|
||||
}
|
||||
// Add to order list to trace if fn_node had side effect.
|
||||
ret_graph->set_output(ret_graph->NewCNodeInOrder(elems));
|
||||
return ret_graph;
|
||||
res_graph->set_output(res_graph->NewCNodeInOrder(elems));
|
||||
return res_graph;
|
||||
}
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -2023,17 +2023,33 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
// Process the object.
|
||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
|
||||
ValuePtr converted_ret = nullptr;
|
||||
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
||||
ValuePtr converted_res = nullptr;
|
||||
bool converted = parse::ConvertData(obj, &converted_res, true);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Convert the python object failed";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(converted_ret);
|
||||
if (converted_ret->isa<FuncGraph>()) {
|
||||
AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
|
||||
MS_EXCEPTION_IF_NULL(converted_res);
|
||||
|
||||
// To check isolated side effect for the func graph who returns constant.
|
||||
if (engine->check_isolated_side_effect()) {
|
||||
MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", converted_res: " << converted_res->ToString();
|
||||
auto prim = GetValueWithoutDoSignature(converted_res)->cast<PrimitivePtr>();
|
||||
if (prim != nullptr) {
|
||||
auto effect_info = GetPrimEffectInfo(prim);
|
||||
if (effect_info.memory || effect_info.io) {
|
||||
MS_LOG(INFO) << "Found Side Effect Primitive CNode: " << out_conf->node()->DebugString();
|
||||
const auto &cnode = dyn_cast<CNode>(out_conf->node());
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_has_isolated_side_effect_node(true);
|
||||
out_conf->func_graph()->set_has_isolated_side_effect_node(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
|
||||
if (converted_res->isa<FuncGraph>()) {
|
||||
AddToManager(engine, converted_res->cast<FuncGraphPtr>());
|
||||
}
|
||||
AbstractBasePtr ret = ToAbstract(converted_res, AnalysisContext::DummyContext(), out_conf);
|
||||
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
|
||||
return infer_result;
|
||||
|
|
|
@ -1863,3 +1863,30 @@ def test_print_assign_print():
|
|||
'param_3:\nTensor(shape=[], dtype=Int32, value=3)\n\n'}
|
||||
check_output(cap.output, patterns)
|
||||
np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_print_in_constant_returned_func():
|
||||
"""
|
||||
Feature: Auto Monad
|
||||
Description: Test print in a func graph who returns constant.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Print(Cell):
|
||||
def construct(self):
|
||||
x = tuple((1, 2, 3, 4, 5))
|
||||
print("x:", x)
|
||||
return x
|
||||
|
||||
cap = Capture()
|
||||
with capture(cap):
|
||||
net = Print()
|
||||
net()
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.1)
|
||||
|
||||
patterns = {'x:\n(1, 2, 3, 4, 5)'}
|
||||
check_output(cap.output, patterns)
|
||||
|
|
Loading…
Reference in New Issue