ensure real_input order

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2020-06-11 19:57:35 +08:00
parent 9b4612f801
commit 821836a00f
4 changed files with 48 additions and 28 deletions

View File

@ -148,8 +148,8 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> gr
std::lock_guard<std::mutex> lock(label_num_mutex_); std::lock_guard<std::mutex> lock(label_num_mutex_);
auto iter = label_num_.find(graph.get()); auto iter = label_num_.find(graph.get());
if (iter == label_num_.end()) { if (iter == label_num_.end()) {
MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 1."; MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 0.";
return 1; return 0;
} }
return iter->second; return iter->second;
} }

View File

@ -40,7 +40,7 @@ static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFind
} }
memo->insert(kg.get()); memo->insert(kg.get());
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs(); const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
for (auto &iter : real_inputs) { for (auto &iter : real_inputs) {
auto &para = iter.first; auto &para = iter.first;
if (para->isa<Parameter>()) { if (para->isa<Parameter>()) {
@ -65,7 +65,7 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
} }
memo->insert(kg.get()); memo->insert(kg.get());
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs(); const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
for (auto &iter : real_inputs) { for (auto &iter : real_inputs) {
auto &para = iter.first; auto &para = iter.first;
for (auto &arg : iter.second) { for (auto &arg : iter.second) {
@ -174,10 +174,14 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
for (auto &iter : graph_id_map) { for (auto &iter : graph_id_map) {
auto &kg = iter.second; auto &kg = iter.second;
MS_EXCEPTION_IF_NULL(kg); MS_EXCEPTION_IF_NULL(kg);
auto real_inputs = kg->real_inputs(); const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
for (auto &it : real_inputs) { for (auto &in : kg->inputs()) {
auto &parameter = it.first; auto it = real_inputs.find(in);
auto &args = it.second; if (it == real_inputs.end()) {
continue;
}
auto &parameter = it->first;
auto &args = it->second;
for (auto &arg : args) { for (auto &arg : args) {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<Parameter>()) { if (arg->isa<Parameter>()) {

View File

@ -677,13 +677,13 @@ void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &ar
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
if (real_inputs_.find(parameter) == real_inputs_.end()) { if (real_inputs_.find(parameter) == real_inputs_.end()) {
real_inputs_[parameter] = std::set<AnfNodePtr>(); real_inputs_[parameter] = std::vector<AnfNodePtr>();
} }
auto &args = real_inputs_[parameter]; auto &args = real_inputs_[parameter];
(void)args.insert(arg); (void)args.push_back(arg);
} }
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) { std::vector<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
auto iter = real_inputs_.find(parameter); auto iter = real_inputs_.find(parameter);
if (iter != real_inputs_.end()) { if (iter != real_inputs_.end()) {
@ -694,7 +694,7 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
void KernelGraph::UpdateCallRealInput() { void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "Update graph id: " << graph_id_; MS_LOG(INFO) << "Update graph id: " << graph_id_;
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_map; std::map<AnfNodePtr, std::vector<AnfNodePtr>> real_inputs_map;
for (auto &it : real_inputs_) { for (auto &it : real_inputs_) {
auto parameter = it.first; auto parameter = it.first;
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
@ -713,12 +713,18 @@ void KernelGraph::UpdateCallRealInput() {
} }
for (auto &erase_node : erase_real_inputs) { for (auto &erase_node : erase_real_inputs) {
MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString(); MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString();
(void)real_inputs.erase(erase_node); for (auto iter = real_inputs.begin(); iter != real_inputs.end();) {
if (*iter == erase_node) {
iter = real_inputs.erase(iter);
} else {
++iter;
}
}
} }
for (auto &new_real_input : new_real_inputs) { for (auto &new_real_input : new_real_inputs) {
MS_LOG(INFO) << "paramter: " << parameter->DebugString() MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " insert real input:" << new_real_input->DebugString(); << " insert real input:" << new_real_input->DebugString();
(void)real_inputs.insert(new_real_input); (void)real_inputs.push_back(new_real_input);
} }
real_inputs_map[parameter] = real_inputs; real_inputs_map[parameter] = real_inputs;
} }
@ -730,18 +736,28 @@ void KernelGraph::PrintGraphExecuteOrder() const {
for (size_t i = 0; i < execution_order_.size(); i++) { for (size_t i = 0; i < execution_order_.size(); i++) {
CNodePtr cur_cnode_ptr = execution_order_[i]; CNodePtr cur_cnode_ptr = execution_order_[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) { std::string event_str;
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); std::string label_str;
MS_LOG(INFO) << "index[" << i << "], node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id[" if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], event_id["
<< GetValue<uint32_t>(primitive->GetAttr(kAttrEventId)) << "], node info["
<< cur_cnode_ptr->DebugString() << "]";
} else {
MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]";
} }
if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
}
if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
label_str = ", label_id[";
for (size_t j = 0; j < label_list.size(); ++j) {
label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
}
}
MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
<< event_str << label_str;
} }
} }

View File

@ -127,8 +127,8 @@ class KernelGraph : public FuncGraph {
// find anf node in graph // find anf node in graph
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
// get real inputs // get real inputs
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs() const { return real_inputs_; } const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs() const { return real_inputs_; }
std::set<AnfNodePtr> GetRealInput(const AnfNodePtr &parameter); std::vector<AnfNodePtr> GetRealInput(const AnfNodePtr &parameter);
void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg); void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
// used to dump ir // used to dump ir
std::string ToString() const override; std::string ToString() const override;
@ -194,7 +194,7 @@ class KernelGraph : public FuncGraph {
// parameter graph // parameter graph
std::shared_ptr<KernelGraph> parent_graph_; std::shared_ptr<KernelGraph> parent_graph_;
// record real parameters,inputs_ is the formal parameters // record real parameters,inputs_ is the formal parameters
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_; std::map<AnfNodePtr, std::vector<AnfNodePtr>> real_inputs_;
CNodePtr start_label_; CNodePtr start_label_;
CNodePtr end_goto_; CNodePtr end_goto_;