forked from OSSInnovation/mindspore
!4030 replace unused parameter in graph inputs
Merge pull request !4030 from laiyongqiang/replace_parameter
This commit is contained in:
commit
c55c0e0f0c
|
@ -261,13 +261,14 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
|
|||
}
|
||||
}
|
||||
|
||||
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph);
|
||||
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph,
|
||||
graph_list);
|
||||
}
|
||||
|
||||
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count,
|
||||
const std::set<CNodePtr> &all_nodes,
|
||||
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node,
|
||||
NotNull<KernelGraphPtr> root_graph) {
|
||||
NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list) {
|
||||
std::vector<CNodePtr> exec_order = root_graph->execution_order();
|
||||
while (parameter_count->HasValidElem()) {
|
||||
auto [para, read, written] = parameter_count->GetOneValidElem();
|
||||
|
@ -292,6 +293,8 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
|
|||
if (visit_source->isa<Parameter>()) {
|
||||
parameter_count->AddReadCount(visit_source, read - 1);
|
||||
}
|
||||
|
||||
// replace parameter in node
|
||||
for (auto &node : all_nodes) {
|
||||
for (size_t i = 0; i < node->size(); ++i) {
|
||||
if (node->input(i) == para) {
|
||||
|
@ -300,6 +303,14 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// replace parameter in graph input
|
||||
for (auto &g : graph_list) {
|
||||
auto child_graph_inputs = g->MutableInputs();
|
||||
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source);
|
||||
MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph "
|
||||
<< g->graph_id() << " inputs";
|
||||
}
|
||||
}
|
||||
root_graph->set_execution_order(exec_order);
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ class AscendControlParser {
|
|||
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
|
||||
static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes,
|
||||
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node,
|
||||
NotNull<KernelGraphPtr> root_graph);
|
||||
NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
|
||||
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
|
||||
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
|
||||
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
|
||||
|
|
|
@ -153,9 +153,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
|
||||
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
|
||||
UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
// add make_tuple to the output graph
|
||||
|
@ -178,7 +175,10 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
debugger_->PreExecute(root_graph);
|
||||
}
|
||||
SetSummaryNodes(root_graph.get());
|
||||
// alloc mem
|
||||
// Alloc memory for child graph's inputs
|
||||
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
// Alloc memory for root graph's inputs and node's outputs, workspace
|
||||
MemoryAlloc(root_graph.get());
|
||||
// generate and load task into device
|
||||
Load(root_graph);
|
||||
|
|
|
@ -337,6 +337,8 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|||
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
|
||||
}
|
||||
MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
|
||||
<< " index: " << index << " size: " << tensor_size;
|
||||
AnfAlgo::SetOutputAddr(address, index, item.get());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue