forked from OSSInnovation/mindspore
!4030 replace unused parameter in graph inputs
Merge pull request !4030 from laiyongqiang/replace_parameter
This commit is contained in:
commit
c55c0e0f0c
|
@ -261,13 +261,14 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph);
|
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph,
|
||||||
|
graph_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count,
|
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count,
|
||||||
const std::set<CNodePtr> &all_nodes,
|
const std::set<CNodePtr> &all_nodes,
|
||||||
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node,
|
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node,
|
||||||
NotNull<KernelGraphPtr> root_graph) {
|
NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list) {
|
||||||
std::vector<CNodePtr> exec_order = root_graph->execution_order();
|
std::vector<CNodePtr> exec_order = root_graph->execution_order();
|
||||||
while (parameter_count->HasValidElem()) {
|
while (parameter_count->HasValidElem()) {
|
||||||
auto [para, read, written] = parameter_count->GetOneValidElem();
|
auto [para, read, written] = parameter_count->GetOneValidElem();
|
||||||
|
@ -292,6 +293,8 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
|
||||||
if (visit_source->isa<Parameter>()) {
|
if (visit_source->isa<Parameter>()) {
|
||||||
parameter_count->AddReadCount(visit_source, read - 1);
|
parameter_count->AddReadCount(visit_source, read - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// replace parameter in node
|
||||||
for (auto &node : all_nodes) {
|
for (auto &node : all_nodes) {
|
||||||
for (size_t i = 0; i < node->size(); ++i) {
|
for (size_t i = 0; i < node->size(); ++i) {
|
||||||
if (node->input(i) == para) {
|
if (node->input(i) == para) {
|
||||||
|
@ -300,6 +303,14 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// replace parameter in graph input
|
||||||
|
for (auto &g : graph_list) {
|
||||||
|
auto child_graph_inputs = g->MutableInputs();
|
||||||
|
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source);
|
||||||
|
MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph "
|
||||||
|
<< g->graph_id() << " inputs";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
root_graph->set_execution_order(exec_order);
|
root_graph->set_execution_order(exec_order);
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ class AscendControlParser {
|
||||||
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
|
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
|
||||||
static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes,
|
static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes,
|
||||||
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node,
|
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node,
|
||||||
NotNull<KernelGraphPtr> root_graph);
|
NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
|
||||||
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
|
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
|
||||||
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
|
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
|
||||||
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
|
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
|
||||||
|
|
|
@ -153,9 +153,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||||
HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
|
HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||||
memo.clear();
|
memo.clear();
|
||||||
|
|
||||||
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
|
|
||||||
memo.clear();
|
|
||||||
|
|
||||||
UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
|
UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||||
memo.clear();
|
memo.clear();
|
||||||
// add make_tuple to the output graph
|
// add make_tuple to the output graph
|
||||||
|
@ -178,7 +175,10 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||||
debugger_->PreExecute(root_graph);
|
debugger_->PreExecute(root_graph);
|
||||||
}
|
}
|
||||||
SetSummaryNodes(root_graph.get());
|
SetSummaryNodes(root_graph.get());
|
||||||
// alloc mem
|
// Alloc memory for child graph's inputs
|
||||||
|
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||||
|
memo.clear();
|
||||||
|
// Alloc memory for root graph's inputs and node's outputs, workspace
|
||||||
MemoryAlloc(root_graph.get());
|
MemoryAlloc(root_graph.get());
|
||||||
// generate and load task into device
|
// generate and load task into device
|
||||||
Load(root_graph);
|
Load(root_graph);
|
||||||
|
|
|
@ -337,6 +337,8 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
||||||
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
|
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
|
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
|
||||||
|
<< " index: " << index << " size: " << tensor_size;
|
||||||
AnfAlgo::SetOutputAddr(address, index, item.get());
|
AnfAlgo::SetOutputAddr(address, index, item.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue