forked from mindspore-Ecosystem/mindspore
!3946 Ignore create parameter from control depend inputs
Merge pull request !3946 from YuJianfeng/master
This commit is contained in:
commit
34214e8f4c
|
@ -288,6 +288,22 @@ bool ExistSummaryNode(const KernelGraph *graph) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &node_inputs = cnode->inputs();
|
||||
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
GraphId SessionBasic::graph_sum_ = 0;
|
||||
|
@ -354,8 +370,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
std::vector<AnfNodePtr> pre_graph_out = {node};
|
||||
if (IgnoreCreateParameterForMakeTuple(node)) {
|
||||
pre_graph_out.clear();
|
||||
}
|
||||
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
|
||||
if (!AnfAlgo::IsRealKernel(node)) {
|
||||
if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
|
||||
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
|
||||
}
|
||||
auto valid_inputs = graph->MutableValidInputs();
|
||||
|
@ -431,7 +450,8 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool
|
|||
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
|
||||
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
|
||||
if (parameters.empty()) {
|
||||
MS_LOG(EXCEPTION) << "No parameter exist!!";
|
||||
MS_LOG(INFO) << "Empty parameter from cnode";
|
||||
return nullptr;
|
||||
}
|
||||
if (parameters.size() == 1) {
|
||||
return parameters[0];
|
||||
|
@ -505,11 +525,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
|
|||
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
|
||||
continue;
|
||||
} else if (optimize_control_depend) {
|
||||
cnode_inputs.push_back(NewValueNode(MakeValue(input_idx)));
|
||||
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
|
||||
} else {
|
||||
*from_other_graph = true;
|
||||
// the input node is a cnode from other graph
|
||||
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
|
||||
if (parameter_from_cnode == nullptr) {
|
||||
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx)));
|
||||
}
|
||||
cnode_inputs.push_back(parameter_from_cnode);
|
||||
(*other_graph_cnode)[anf] = parameter_from_cnode;
|
||||
}
|
||||
|
@ -878,7 +901,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
auto tensor = inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto input_node = input_nodes[i];
|
||||
if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
|
||||
if (ms_context->execution_mode() == kPynativeMode ||
|
||||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
|
||||
|
|
|
@ -79,6 +79,42 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, c
|
|||
return output;
|
||||
}
|
||||
|
||||
namespace {
|
||||
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
|
||||
AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_EXCEPTION_IF_NULL(inputs_ptr);
|
||||
MS_EXCEPTION_IF_NULL(eqv_ptr);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto &inputs = *inputs_ptr;
|
||||
auto &eqv = *eqv_ptr;
|
||||
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
|
||||
eqv[node] = node;
|
||||
} else if (eqv.find(node) == eqv.end()) {
|
||||
bool ignore_make_tuple = false;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
ignore_make_tuple = true;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &node_inputs = cnode->inputs();
|
||||
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
||||
if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) {
|
||||
ignore_make_tuple = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!ignore_make_tuple) {
|
||||
inputs.push_back(node);
|
||||
}
|
||||
eqv[node] = fg->add_parameter();
|
||||
eqv[node]->set_abstract(node->abstract());
|
||||
eqv[node]->set_kernel_info(node->kernel_info_ptr());
|
||||
}
|
||||
return eqv[node];
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
AnfNodePtrList inputs;
|
||||
|
@ -86,17 +122,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
if (lst.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Input anf node list is empty";
|
||||
}
|
||||
auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr {
|
||||
if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) {
|
||||
eqv[a] = a;
|
||||
} else if (eqv.find(a) == eqv.end()) {
|
||||
inputs.push_back(a);
|
||||
eqv[a] = fg->add_parameter();
|
||||
eqv[a]->set_abstract(a->abstract());
|
||||
eqv[a]->set_kernel_info(a->kernel_info_ptr());
|
||||
}
|
||||
return eqv[a];
|
||||
};
|
||||
// Merge CNodes into a AnfGraph that represents a linear instruction segment
|
||||
for (auto n : lst) {
|
||||
if (!n->isa<CNode>()) {
|
||||
|
@ -122,11 +147,12 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
|
||||
args.emplace_back(NewValueNode(MakeValue(i)));
|
||||
} else {
|
||||
args.emplace_back(ref(inps[i]));
|
||||
args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
|
||||
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
|
||||
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
|
||||
}
|
||||
eqv[n] = fg->NewCNode(args);
|
||||
eqv[n]->set_abstract(n->abstract());
|
||||
|
|
Loading…
Reference in New Issue