!1891 New control sink support resnet50

Merge pull request !1891 from zhoufeng/new-control-sink-resnet50
This commit is contained in:
mindspore-ci-bot 2020-06-08 19:08:11 +08:00 committed by Gitee
commit ffb5339e87
5 changed files with 76 additions and 113 deletions

View File

@ -33,11 +33,9 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
if (node->size() <= kLabelGotoLabelId) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
}
auto label_set = AnfAlgo::GetCNodePrimitive(node->input(kLabelGotoLabelId));
MS_EXCEPTION_IF_NULL(label_set);
auto value = label_set->GetAttr(kAttrLabelIndex);
MS_EXCEPTION_IF_NULL(value);
uint32_t goto_label_id = GetValue<uint32_t>(value);
auto input = node->input(kLabelGotoLabelId);
uint32_t goto_label_id = AnfAlgo::GetNodeAttr<uint32_t>(input, kAttrLabelIndex);
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get());
MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id;
node->set_inputs({node->input(0)});
@ -57,11 +55,7 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
break;
}
auto label_set = AnfAlgo::GetCNodePrimitive(input);
MS_EXCEPTION_IF_NULL(label_set);
auto value = label_set->GetAttr(kAttrLabelIndex);
MS_EXCEPTION_IF_NULL(value);
uint32_t goto_label_id = GetValue<uint32_t>(value);
uint32_t goto_label_id = AnfAlgo::GetNodeAttr<uint32_t>(input, kAttrLabelIndex);
label_list.push_back(goto_label_id);
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
}
@ -154,7 +148,7 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> gr
std::lock_guard<std::mutex> lock(label_num_mutex_);
auto iter = label_num_.find(graph.get());
if (iter == label_num_.end()) {
MS_LOG(WARNING) << "Graph " << graph->ToString() << " has not assigned label.";
MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 1.";
return 1;
}
return iter->second;

View File

@ -33,31 +33,6 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace mindspore {
namespace session {
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
for (auto &iter : graph_id_map) {
auto &kg = iter.second;
MS_EXCEPTION_IF_NULL(kg);
auto real_inputs = kg->real_inputs();
for (auto &it : real_inputs) {
auto &parameter = it.first;
auto &args = it.second;
for (auto &arg : args) {
MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<Parameter>()) {
MS_LOG(INFO) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
<< ", arg:" << arg->DebugString();
continue;
}
auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
if (target_graph_iter == graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
}
InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
}
}
}
}
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg.get()) != memo->end()) {
@ -89,6 +64,7 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
return;
}
memo->insert(kg.get());
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs();
for (auto &iter : real_inputs) {
auto &para = iter.first;
@ -150,11 +126,10 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet
const auto &root_inputs_vector = root_kg->inputs();
root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
for (auto &node : parameter_reuse_set) {
if (root_inputs_set.find(node) == root_inputs_set.end()) {
continue;
if (root_inputs_set.find(node) != root_inputs_set.end()) {
main_parameter = node;
break;
}
main_parameter = node;
}
std::set<KernelGraphPtr> memo;
@ -162,9 +137,18 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet
}
}
CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
for (size_t i = start; i < list.size() - 1; ++i) {
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
return list[i];
}
}
return nullptr;
}
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
std::set<KernelGraphPtr> memo;
ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
std::map<uint32_t, KernelGraphPtr> graph_id_map;
for (auto &g : memo) {
if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
@ -181,13 +165,34 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
ChildGraphDataAssign(graph_id_map);
}
CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
for (size_t i = start; i < list.size() - 1; ++i) {
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
return list[i];
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
std::set<KernelGraphPtr> memo;
(void)RecurseGraph(root_graph, NOT_NULL(&memo));
}
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
for (auto &iter : graph_id_map) {
auto &kg = iter.second;
MS_EXCEPTION_IF_NULL(kg);
auto real_inputs = kg->real_inputs();
for (auto &it : real_inputs) {
auto &parameter = it.first;
auto &args = it.second;
for (auto &arg : args) {
MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<Parameter>()) {
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
<< ", arg:" << arg->DebugString();
continue;
}
auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
if (target_graph_iter == graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
}
InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
}
}
}
return nullptr;
}
NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
@ -212,9 +217,16 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!";
}
// 4. insert first_label
auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
kg->set_start_label(start_label);
CNodePtr start_label;
if (last_node != nullptr && last_label != nullptr) {
start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
kg->set_start_label(start_label);
} else {
// no goto node will jump to start label of root graph, so return a fake label
start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr));
}
// 5. traverse
for (size_t i = 0; i < nodes.size(); ++i) {
auto &cnode = nodes[i];
@ -249,11 +261,10 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
}
void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) {
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
auto return_node = kg->get_return();
MS_EXCEPTION_IF_NULL(return_node);
inputs.push_back(return_node->input(1));
inputs.push_back(attch_node.get());
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
return_node->input(1), attch_node.get()};
auto depend_node = kg->NewCNode(inputs);
return_node->set_input(1, depend_node);
}
@ -407,9 +418,9 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
if (partial_cnode->size() < kCNodePartialLength) {
MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
}
auto partial_inputs = partial_cnode->inputs();
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
return {partial_cnode, branch_kg};
}
@ -425,7 +436,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
<< to->DebugString();
// config inputs of assign node
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAssign->name())), to, from};
// generate a new cnode
auto assign_node = kg->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(assign_node);
@ -434,11 +445,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
InsertDependToGraph(kg, NOT_NULL(assign_node));
}
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
std::set<KernelGraphPtr> memo;
(void)RecurseGraph(root_graph, NOT_NULL(&memo));
}
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
@ -457,29 +463,24 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
if (node == graph->get_end_goto()) {
continue;
}
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
if (!CheckLabelIndex(child_order_index, 0, node, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
auto child_graph = graph->child_graph_order()[child_order_index++];
if (child_graph == graph->parent_graph()) {
continue;
}
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
std::vector<uint32_t> label_switch_list = GetLabelSwitchList(node);
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
auto child_graph = graph->child_graph_order()[child_order_index++];
if (child_graph == graph->parent_graph()) {
continue;
}
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
}
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
if (!CheckLabelIndex(child_order_index, label_index, node, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
auto child_graph = graph->child_graph_order()[child_order_index++];
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
}
}
graph->set_execution_order(execution_order);
@ -487,15 +488,6 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
return execution_order;
}
std::vector<uint32_t> AscendControlParser::GetLabelSwitchList(const CNodePtr &node) {
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
}
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
return GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
}
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
NotNull<KernelGraphPtr> graph) {
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
@ -504,33 +496,19 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
<< child_graph_order.size() << " goto index " << order_index;
}
if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) {
// check label_goto and start_label in child graph
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_label)) {
MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
}
auto primitive = AnfAlgo::GetCNodePrimitive(cur_label);
MS_EXCEPTION_IF_NULL(primitive);
uint32_t label_goto_index = GetValue<uint32_t>(primitive->GetAttr(kAttrLabelIndex));
label_index = label_goto_index;
}
// get start_label_set_index of child graph
auto child_graph = child_graph_order[order_index];
MS_EXCEPTION_IF_NULL(child_graph);
// get start_label_set_index of child graph
auto start_label_set = child_graph->get_start_label();
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) {
MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
}
auto start_primitive = AnfAlgo::GetCNodePrimitive(start_label_set);
MS_EXCEPTION_IF_NULL(start_primitive);
uint32_t start_label_set_index = GetValue<uint32_t>(start_primitive->GetAttr(kAttrLabelIndex));
uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex);
if (label_index != start_label_set_index) {
MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString()
<< " index " << start_label_set_index << " current child graph order : " << order_index;
return false;
} else {
return true;
}
return true;
}
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {

View File

@ -54,10 +54,7 @@ class AscendControlParser {
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start);
// root graph order
static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node);
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
NotNull<KernelGraphPtr> graph);
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,

View File

@ -377,8 +377,8 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
MS_EXCEPTION_IF_NULL(old_backend_anf);
MS_EXCEPTION_IF_NULL(new_backend_anf);
if (old_backend_anf == new_backend_anf) {
MS_LOG(INFO) << "old:" << old_backend_anf->DebugString() << ",new:" << new_backend_anf->DebugString();
MS_LOG(EXCEPTION) << "old can't be same with new";
MS_LOG(DEBUG) << "old same with new:" << old_backend_anf->DebugString();
return;
}
if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";

View File

@ -620,12 +620,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
// if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
graph->set_output_null(is_trace_back);
AddParameterToGraphInputs(func_graph->parameters(), graph.get());
MS_EXCEPTION_IF_NULL(context_);
FuncGraphManagerPtr manager = MakeManager({graph});
if (manager) {
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);