!21507 fix insert cast in cpu pynative mode

Merge pull request !21507 from baihuawei/insert_cast_fix
This commit is contained in:
i-robot 2021-08-09 16:52:58 +00:00 committed by Gitee
commit ef2d3ad850
2 changed files with 9 additions and 2 deletions

View File

@ -693,6 +693,9 @@ void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNode
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
auto input_node = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<CNode>() && AnfAlgo::GetInputTensorNum(input_node) == 0) {
continue;
}
output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
}
} else {

View File

@ -91,7 +91,7 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
}
}
void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &func_output) {
MS_EXCEPTION_IF_NULL(cnode);
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t i = 0; i < output_num; i++) {
@ -102,6 +102,9 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
for (size_t j = 0; j < used_node_list->size(); j++) {
auto used_node = used_node_list->at(j).first;
if (used_node != func_output) {
continue;
}
auto used_node_index = static_cast<size_t>(used_node_list->at(j).second - 1);
auto cur_input = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(used_node), used_node_index);
const std::vector<size_t> origin_shape =
@ -128,10 +131,11 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) {
}
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
auto func_output = func_graph->output();
for (auto node : outputs) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
auto cnode = node->cast<CNodePtr>();
InsertCastForGraphOutput(func_graph, cnode);
InsertCastForGraphOutput(func_graph, cnode, func_output);
}
}
return true;