forked from mindspore-Ecosystem/mindspore
!21507 fix insert cast in cpu pynative mode
Merge pull request !21507 from baihuawei/insert_cast_fix
This commit is contained in:
commit
ef2d3ad850
|
@@ -693,6 +693,9 @@ void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNode
|
||||||
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
|
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
|
||||||
auto input_node = cnode->input(input_idx);
|
auto input_node = cnode->input(input_idx);
|
||||||
MS_EXCEPTION_IF_NULL(input_node);
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
|
if (input_node->isa<CNode>() && AnfAlgo::GetInputTensorNum(input_node) == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
|
output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@@ -91,7 +91,7 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &func_output) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||||
for (size_t i = 0; i < output_num; i++) {
|
for (size_t i = 0; i < output_num; i++) {
|
||||||
|
@@ -102,6 +102,9 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn
|
||||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
|
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
|
||||||
for (size_t j = 0; j < used_node_list->size(); j++) {
|
for (size_t j = 0; j < used_node_list->size(); j++) {
|
||||||
auto used_node = used_node_list->at(j).first;
|
auto used_node = used_node_list->at(j).first;
|
||||||
|
if (used_node != func_output) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
auto used_node_index = static_cast<size_t>(used_node_list->at(j).second - 1);
|
auto used_node_index = static_cast<size_t>(used_node_list->at(j).second - 1);
|
||||||
auto cur_input = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(used_node), used_node_index);
|
auto cur_input = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(used_node), used_node_index);
|
||||||
const std::vector<size_t> origin_shape =
|
const std::vector<size_t> origin_shape =
|
||||||
|
@@ -128,10 +131,11 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) {
|
||||||
}
|
}
|
||||||
AnfNodePtrList outputs;
|
AnfNodePtrList outputs;
|
||||||
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
|
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
|
||||||
|
auto func_output = func_graph->output();
|
||||||
for (auto node : outputs) {
|
for (auto node : outputs) {
|
||||||
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
InsertCastForGraphOutput(func_graph, cnode);
|
InsertCastForGraphOutput(func_graph, cnode, func_output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
Loading…
Reference in New Issue