Fix the unused parameter elimination

This commit is contained in:
yujianfeng 2024-02-19 11:48:53 +08:00 committed by r1chardf1d0
parent ec03b6d9e6
commit 4e8aff4f0e
1 changed files with 3 additions and 10 deletions

View File

@ -440,19 +440,12 @@ void FuncGraphBuilder::EraseUnusedParameter() {
}
const auto &nodes = graph_->TopoSort(graph_->output());
std::unordered_set<AnfNodePtr> used_params;
for (auto node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
const auto &cnode_inputs = cnode->inputs();
(void)std::copy_if(cnode_inputs.begin(), cnode_inputs.end(), std::inserter(used_params, used_params.begin()),
[](const AnfNodePtr &input) { return input->isa<Parameter>(); });
}
(void)std::copy_if(nodes.begin(), nodes.end(), std::inserter(used_params, used_params.begin()),
[](const AnfNodePtr &node) { return node->isa<Parameter>(); });
std::vector<AnfNodePtr> new_params;
const auto &origin_params = graph_->parameters();
(void)std::copy_if(origin_params.begin(), origin_params.end(), std::back_inserter(new_params),
[&used_params](const AnfNodePtr param) { return used_params.find(param) != used_params.end(); });
[&used_params](const AnfNodePtr &param) { return used_params.find(param) != used_params.end(); });
graph_->set_parameters(new_params);
}