!2270 fix code review

Merge pull request !2270 from lianliguang/master
This commit is contained in:
mindspore-ci-bot 2020-06-18 16:15:05 +08:00 committed by Gitee
commit e32d539b5f
1 changed files with 11 additions and 3 deletions

View File

@ -31,8 +31,9 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
AnfNodePtr cur_node = kernel_with_index.first; AnfNodePtr cur_node = kernel_with_index.first;
size_t cur_out_index = kernel_with_index.second; size_t cur_out_index = kernel_with_index.second;
MS_EXCEPTION_IF_NULL(cur_node);
if (cur_node->isa<CNode>()) { if (cur_node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = cur_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::string op_name = AnfAlgo::GetCNodeName(cnode); std::string op_name = AnfAlgo::GetCNodeName(cnode);
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
@ -88,7 +89,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
size_t input_index, const AnfNodePtr &get_item) { size_t input_index, const AnfNodePtr &get_item) {
AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item);
size_t final_index = output_index; size_t final_index = output_index;
AnfNodePtr input_node = cnode->input(input_index + 1); AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index);
session::KernelWithIndex origin_pair; session::KernelWithIndex origin_pair;
origin_pair = FindRefOriginNode(input_node); origin_pair = FindRefOriginNode(input_node);
MS_EXCEPTION_IF_NULL(origin_pair.first); MS_EXCEPTION_IF_NULL(origin_pair.first);
@ -133,6 +134,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
} }
AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) { const std::shared_ptr<kernel::OpInfo> &op_info) {
MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos(); auto ref_infos = op_info->ref_infos();
std::vector<AnfNodePtr> make_tuple_inputs; std::vector<AnfNodePtr> make_tuple_inputs;
AbstractBasePtrList abstract_list; AbstractBasePtrList abstract_list;
@ -144,9 +146,11 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP
auto input_index = ref_infos.at(output_index); auto input_index = ref_infos.at(output_index);
final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node); final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node);
} }
MS_EXCEPTION_IF_NULL(final_node);
abstract_list.push_back(final_node->abstract()); abstract_list.push_back(final_node->abstract());
make_tuple_inputs.push_back(final_node); make_tuple_inputs.push_back(final_node);
} }
MS_EXCEPTION_IF_NULL(func_graph);
AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple); MS_EXCEPTION_IF_NULL(make_tuple);
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
@ -155,6 +159,8 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP
AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) { const std::shared_ptr<kernel::OpInfo> &op_info) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos(); auto ref_infos = op_info->ref_infos();
for (const auto &ref_info : ref_infos) { for (const auto &ref_info : ref_infos) {
if (ref_info.second > cnode->inputs().size()) { if (ref_info.second > cnode->inputs().size()) {
@ -206,7 +212,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
return nullptr; return nullptr;
} }
if (op_info->is_ref()) { if (op_info->is_ref()) {
if (!cnode->Type()->isa<Tuple>()) { auto type = cnode->Type();
MS_EXCEPTION_IF_NULL(type);
if (!type->isa<Tuple>()) {
return DealRefSigleOutput(graph, cnode, op_info); return DealRefSigleOutput(graph, cnode, op_info);
} else { } else {
return DealRefForMultipleOutput(graph, cnode, op_info); return DealRefForMultipleOutput(graph, cnode, op_info);