!4030 replace unused parameter in graph inputs

Merge pull request !4030 from laiyongqiang/replace_parameter
This commit is contained in:
mindspore-ci-bot 2020-09-24 16:01:00 +08:00 committed by Gitee
commit c55c0e0f0c
4 changed files with 20 additions and 7 deletions

View File

@ -261,13 +261,14 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
} }
} }
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph); EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph,
graph_list);
} }
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count,
const std::set<CNodePtr> &all_nodes, const std::set<CNodePtr> &all_nodes,
const std::map<AnfNodePtr, CNodePtr> &para_to_written_node, const std::map<AnfNodePtr, CNodePtr> &para_to_written_node,
NotNull<KernelGraphPtr> root_graph) { NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list) {
std::vector<CNodePtr> exec_order = root_graph->execution_order(); std::vector<CNodePtr> exec_order = root_graph->execution_order();
while (parameter_count->HasValidElem()) { while (parameter_count->HasValidElem()) {
auto [para, read, written] = parameter_count->GetOneValidElem(); auto [para, read, written] = parameter_count->GetOneValidElem();
@ -292,6 +293,8 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
if (visit_source->isa<Parameter>()) { if (visit_source->isa<Parameter>()) {
parameter_count->AddReadCount(visit_source, read - 1); parameter_count->AddReadCount(visit_source, read - 1);
} }
// replace parameter in node
for (auto &node : all_nodes) { for (auto &node : all_nodes) {
for (size_t i = 0; i < node->size(); ++i) { for (size_t i = 0; i < node->size(); ++i) {
if (node->input(i) == para) { if (node->input(i) == para) {
@ -300,6 +303,14 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
} }
} }
} }
// replace parameter in graph input
for (auto &g : graph_list) {
auto child_graph_inputs = g->MutableInputs();
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source);
MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph "
<< g->graph_id() << " inputs";
}
} }
root_graph->set_execution_order(exec_order); root_graph->set_execution_order(exec_order);
} }

View File

@ -47,7 +47,7 @@ class AscendControlParser {
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list); static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes, static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes,
const std::map<AnfNodePtr, CNodePtr> &para_to_written_node, const std::map<AnfNodePtr, CNodePtr> &para_to_written_node,
NotNull<KernelGraphPtr> root_graph); NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
static void EraseLabel(NotNull<KernelGraphPtr> root_graph); static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg, static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list, const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,

View File

@ -153,9 +153,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo)); HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear(); memo.clear();
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo)); UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear(); memo.clear();
// add make_tuple to the output graph // add make_tuple to the output graph
@ -178,7 +175,10 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
debugger_->PreExecute(root_graph); debugger_->PreExecute(root_graph);
} }
SetSummaryNodes(root_graph.get()); SetSummaryNodes(root_graph.get());
// alloc mem // Alloc memory for child graph's inputs
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
// Alloc memory for root graph's inputs and node's outputs, workspace
MemoryAlloc(root_graph.get()); MemoryAlloc(root_graph.get());
// generate and load task into device // generate and load task into device
Load(root_graph); Load(root_graph);

View File

@ -337,6 +337,8 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
} }
MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
<< " index: " << index << " size: " << tensor_size;
AnfAlgo::SetOutputAddr(address, index, item.get()); AnfAlgo::SetOutputAddr(address, index, item.get());
} }
} }