Fix the unused parameter elimination
This commit is contained in:
parent
ec03b6d9e6
commit
4e8aff4f0e
|
@ -440,19 +440,12 @@ void FuncGraphBuilder::EraseUnusedParameter() {
|
|||
}
|
||||
const auto &nodes = graph_->TopoSort(graph_->output());
|
||||
std::unordered_set<AnfNodePtr> used_params;
|
||||
for (auto node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
const auto &cnode_inputs = cnode->inputs();
|
||||
(void)std::copy_if(cnode_inputs.begin(), cnode_inputs.end(), std::inserter(used_params, used_params.begin()),
|
||||
[](const AnfNodePtr &input) { return input->isa<Parameter>(); });
|
||||
}
|
||||
(void)std::copy_if(nodes.begin(), nodes.end(), std::inserter(used_params, used_params.begin()),
|
||||
[](const AnfNodePtr &node) { return node->isa<Parameter>(); });
|
||||
std::vector<AnfNodePtr> new_params;
|
||||
const auto &origin_params = graph_->parameters();
|
||||
(void)std::copy_if(origin_params.begin(), origin_params.end(), std::back_inserter(new_params),
|
||||
[&used_params](const AnfNodePtr param) { return used_params.find(param) != used_params.end(); });
|
||||
[&used_params](const AnfNodePtr ¶m) { return used_params.find(param) != used_params.end(); });
|
||||
graph_->set_parameters(new_params);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue