forked from mindspore-Ecosystem/mindspore
!1891 New control sink support resnet50
Merge pull request !1891 from zhoufeng/new-control-sink-resnet50
This commit is contained in:
commit
ffb5339e87
|
@ -33,11 +33,9 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
|
|||
if (node->size() <= kLabelGotoLabelId) {
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
|
||||
}
|
||||
auto label_set = AnfAlgo::GetCNodePrimitive(node->input(kLabelGotoLabelId));
|
||||
MS_EXCEPTION_IF_NULL(label_set);
|
||||
auto value = label_set->GetAttr(kAttrLabelIndex);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
uint32_t goto_label_id = GetValue<uint32_t>(value);
|
||||
|
||||
auto input = node->input(kLabelGotoLabelId);
|
||||
uint32_t goto_label_id = AnfAlgo::GetNodeAttr<uint32_t>(input, kAttrLabelIndex);
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get());
|
||||
MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id;
|
||||
node->set_inputs({node->input(0)});
|
||||
|
@ -57,11 +55,7 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
|
|||
break;
|
||||
}
|
||||
|
||||
auto label_set = AnfAlgo::GetCNodePrimitive(input);
|
||||
MS_EXCEPTION_IF_NULL(label_set);
|
||||
auto value = label_set->GetAttr(kAttrLabelIndex);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
uint32_t goto_label_id = GetValue<uint32_t>(value);
|
||||
uint32_t goto_label_id = AnfAlgo::GetNodeAttr<uint32_t>(input, kAttrLabelIndex);
|
||||
label_list.push_back(goto_label_id);
|
||||
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
|
||||
}
|
||||
|
@ -154,7 +148,7 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> gr
|
|||
std::lock_guard<std::mutex> lock(label_num_mutex_);
|
||||
auto iter = label_num_.find(graph.get());
|
||||
if (iter == label_num_.end()) {
|
||||
MS_LOG(WARNING) << "Graph " << graph->ToString() << " has not assigned label.";
|
||||
MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 1.";
|
||||
return 1;
|
||||
}
|
||||
return iter->second;
|
||||
|
|
|
@ -33,31 +33,6 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
|
|||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
|
||||
for (auto &iter : graph_id_map) {
|
||||
auto &kg = iter.second;
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
auto real_inputs = kg->real_inputs();
|
||||
for (auto &it : real_inputs) {
|
||||
auto ¶meter = it.first;
|
||||
auto &args = it.second;
|
||||
for (auto &arg : args) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (arg->isa<Parameter>()) {
|
||||
MS_LOG(INFO) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
|
||||
<< ", arg:" << arg->DebugString();
|
||||
continue;
|
||||
}
|
||||
auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
|
||||
if (target_graph_iter == graph_id_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
|
||||
}
|
||||
InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
if (memo->find(kg.get()) != memo->end()) {
|
||||
|
@ -89,6 +64,7 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
|
|||
return;
|
||||
}
|
||||
memo->insert(kg.get());
|
||||
|
||||
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs();
|
||||
for (auto &iter : real_inputs) {
|
||||
auto ¶ = iter.first;
|
||||
|
@ -150,11 +126,10 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet
|
|||
const auto &root_inputs_vector = root_kg->inputs();
|
||||
root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
|
||||
for (auto &node : parameter_reuse_set) {
|
||||
if (root_inputs_set.find(node) == root_inputs_set.end()) {
|
||||
continue;
|
||||
if (root_inputs_set.find(node) != root_inputs_set.end()) {
|
||||
main_parameter = node;
|
||||
break;
|
||||
}
|
||||
|
||||
main_parameter = node;
|
||||
}
|
||||
|
||||
std::set<KernelGraphPtr> memo;
|
||||
|
@ -162,9 +137,18 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet
|
|||
}
|
||||
}
|
||||
|
||||
CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
|
||||
for (size_t i = start; i < list.size() - 1; ++i) {
|
||||
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
|
||||
return list[i];
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
|
||||
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
|
||||
std::map<uint32_t, KernelGraphPtr> graph_id_map;
|
||||
for (auto &g : memo) {
|
||||
if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
|
||||
|
@ -181,13 +165,34 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
|||
ChildGraphDataAssign(graph_id_map);
|
||||
}
|
||||
|
||||
CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
|
||||
for (size_t i = start; i < list.size() - 1; ++i) {
|
||||
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
|
||||
return list[i];
|
||||
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
(void)RecurseGraph(root_graph, NOT_NULL(&memo));
|
||||
}
|
||||
|
||||
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
|
||||
for (auto &iter : graph_id_map) {
|
||||
auto &kg = iter.second;
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
auto real_inputs = kg->real_inputs();
|
||||
for (auto &it : real_inputs) {
|
||||
auto ¶meter = it.first;
|
||||
auto &args = it.second;
|
||||
for (auto &arg : args) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (arg->isa<Parameter>()) {
|
||||
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
|
||||
<< ", arg:" << arg->DebugString();
|
||||
continue;
|
||||
}
|
||||
auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
|
||||
if (target_graph_iter == graph_id_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
|
||||
}
|
||||
InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
|
@ -212,9 +217,16 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
|
|||
MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!";
|
||||
}
|
||||
// 4. insert first_label
|
||||
auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
|
||||
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
|
||||
kg->set_start_label(start_label);
|
||||
CNodePtr start_label;
|
||||
if (last_node != nullptr && last_label != nullptr) {
|
||||
start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
|
||||
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
|
||||
kg->set_start_label(start_label);
|
||||
} else {
|
||||
// no goto node will jump to start label of root graph, so return a fake label
|
||||
start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr));
|
||||
}
|
||||
|
||||
// 5. traverse
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
auto &cnode = nodes[i];
|
||||
|
@ -249,11 +261,10 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
|
|||
}
|
||||
|
||||
void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) {
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))};
|
||||
auto return_node = kg->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
inputs.push_back(return_node->input(1));
|
||||
inputs.push_back(attch_node.get());
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
||||
return_node->input(1), attch_node.get()};
|
||||
auto depend_node = kg->NewCNode(inputs);
|
||||
return_node->set_input(1, depend_node);
|
||||
}
|
||||
|
@ -407,9 +418,9 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
|
|||
if (partial_cnode->size() < kCNodePartialLength) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
|
||||
}
|
||||
|
||||
auto partial_inputs = partial_cnode->inputs();
|
||||
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
|
||||
|
||||
return {partial_cnode, branch_kg};
|
||||
}
|
||||
|
||||
|
@ -425,7 +436,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
|
|||
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
|
||||
<< to->DebugString();
|
||||
// config inputs of assign node
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from};
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAssign->name())), to, from};
|
||||
// generate a new cnode
|
||||
auto assign_node = kg->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(assign_node);
|
||||
|
@ -434,11 +445,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
|
|||
InsertDependToGraph(kg, NOT_NULL(assign_node));
|
||||
}
|
||||
|
||||
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
(void)RecurseGraph(root_graph, NOT_NULL(&memo));
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
|
||||
|
@ -457,29 +463,24 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
|
|||
if (node == graph->get_end_goto()) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
|
||||
if (!CheckLabelIndex(child_order_index, 0, node, graph)) {
|
||||
MS_LOG(EXCEPTION) << "Check label index fail";
|
||||
}
|
||||
auto child_graph = graph->child_graph_order()[child_order_index++];
|
||||
if (child_graph == graph->parent_graph()) {
|
||||
continue;
|
||||
}
|
||||
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
|
||||
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
||||
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
|
||||
std::vector<uint32_t> label_switch_list = GetLabelSwitchList(node);
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
|
||||
std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
|
||||
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
|
||||
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
|
||||
MS_LOG(EXCEPTION) << "Check label index fail";
|
||||
}
|
||||
auto child_graph = graph->child_graph_order()[child_order_index++];
|
||||
if (child_graph == graph->parent_graph()) {
|
||||
continue;
|
||||
}
|
||||
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
|
||||
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
|
||||
uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
if (!CheckLabelIndex(child_order_index, label_index, node, graph)) {
|
||||
MS_LOG(EXCEPTION) << "Check label index fail";
|
||||
}
|
||||
auto child_graph = graph->child_graph_order()[child_order_index++];
|
||||
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
|
||||
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
||||
}
|
||||
}
|
||||
graph->set_execution_order(execution_order);
|
||||
|
@ -487,15 +488,6 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
|
|||
return execution_order;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> AscendControlParser::GetLabelSwitchList(const CNodePtr &node) {
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
|
||||
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
|
||||
}
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
|
||||
}
|
||||
|
||||
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
|
||||
NotNull<KernelGraphPtr> graph) {
|
||||
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
|
||||
|
@ -504,33 +496,19 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
|
|||
MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
|
||||
<< child_graph_order.size() << " goto index " << order_index;
|
||||
}
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) {
|
||||
// check label_goto and start_label in child graph
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_label)) {
|
||||
MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
|
||||
}
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cur_label);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
uint32_t label_goto_index = GetValue<uint32_t>(primitive->GetAttr(kAttrLabelIndex));
|
||||
label_index = label_goto_index;
|
||||
}
|
||||
// get start_label_set_index of child graph
|
||||
auto child_graph = child_graph_order[order_index];
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
|
||||
// get start_label_set_index of child graph
|
||||
auto start_label_set = child_graph->get_start_label();
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) {
|
||||
MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
|
||||
}
|
||||
auto start_primitive = AnfAlgo::GetCNodePrimitive(start_label_set);
|
||||
MS_EXCEPTION_IF_NULL(start_primitive);
|
||||
uint32_t start_label_set_index = GetValue<uint32_t>(start_primitive->GetAttr(kAttrLabelIndex));
|
||||
uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex);
|
||||
if (label_index != start_label_set_index) {
|
||||
MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString()
|
||||
<< " index " << start_label_set_index << " current child graph order : " << order_index;
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
|
||||
|
|
|
@ -54,10 +54,7 @@ class AscendControlParser {
|
|||
|
||||
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
|
||||
static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start);
|
||||
|
||||
// root graph order
|
||||
static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node);
|
||||
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
|
||||
NotNull<KernelGraphPtr> graph);
|
||||
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||
|
|
|
@ -377,8 +377,8 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
|
|||
MS_EXCEPTION_IF_NULL(old_backend_anf);
|
||||
MS_EXCEPTION_IF_NULL(new_backend_anf);
|
||||
if (old_backend_anf == new_backend_anf) {
|
||||
MS_LOG(INFO) << "old:" << old_backend_anf->DebugString() << ",new:" << new_backend_anf->DebugString();
|
||||
MS_LOG(EXCEPTION) << "old can't be same with new";
|
||||
MS_LOG(DEBUG) << "old same with new:" << old_backend_anf->DebugString();
|
||||
return;
|
||||
}
|
||||
if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
|
||||
MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
|
||||
|
|
|
@ -620,12 +620,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
// if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
|
||||
graph->set_output_null(is_trace_back);
|
||||
AddParameterToGraphInputs(func_graph->parameters(), graph.get());
|
||||
MS_EXCEPTION_IF_NULL(context_);
|
||||
FuncGraphManagerPtr manager = MakeManager({graph});
|
||||
if (manager) {
|
||||
manager->AddFuncGraph(graph);
|
||||
graph->set_manager(manager);
|
||||
}
|
||||
graph->SetExecOrderByDefault();
|
||||
if (ExistSummaryNode(graph.get())) {
|
||||
graph->set_summary_node_exist(true);
|
||||
|
|
Loading…
Reference in New Issue