!3333 Fix internal multiple outputs check

Merge pull request !3333 from YuJianfeng/master
This commit is contained in:
mindspore-ci-bot 2020-07-23 11:38:11 +08:00 committed by Gitee
commit 24dd2fc2a6
9 changed files with 25 additions and 36 deletions

View File

@ -153,7 +153,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
}
make_tuple_inputs.emplace_back(trans_op);
@ -265,7 +265,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
// Single output
if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, 0)) {
kernel_graph->ReplaceInternalOutput(node, new_node);
}
return new_node;

View File

@ -65,7 +65,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
}
} else {
@ -114,7 +114,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, 0)) {
kernel_graph->ReplaceInternalOutput(cnode, replace_node);
}
}

View File

@ -58,7 +58,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c
if (kernel_graph == nullptr) {
return nullptr;
}
if (!kernel_graph->IsInternalOutput(node)) {
if (!kernel_graph->IsInternalOutput(node, 0)) {
return nullptr;
}
if (!UsedForOutputOnly(func_graph, node)) {

View File

@ -1021,27 +1021,17 @@ AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_nod
return nullptr;
}
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) {
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, int output_idx) const {
auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
if (front_nodes_iter != internal_outputs_to_front_map_.end()) {
if (output_idx == -1) {
return true;
}
return false;
}
void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
if (node == nullptr) {
return;
}
(void)final_output_kernels_.insert(node);
}
bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
if (node == nullptr) {
return false;
}
if (final_output_kernels_.find(node) != final_output_kernels_.end()) {
auto &front_nodes = front_nodes_iter->second;
if (front_nodes.find(output_idx) != front_nodes.end()) {
return true;
}
}
return false;
}

View File

@ -153,9 +153,7 @@ class KernelGraph : public FuncGraph {
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
int dst_output_idx = -1);
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
bool IsInternalOutput(const AnfNodePtr &node) const;
void AddFinalOutputKernel(const AnfNodePtr &node);
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
uint32_t current_epoch() const { return current_epoch_; }
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
void UpdateChildGraphOrder();
@ -230,7 +228,6 @@ class KernelGraph : public FuncGraph {
bool null_output_;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
std::set<AnfNodePtr> final_output_kernels_;
uint32_t current_epoch_;
};
} // namespace session

View File

@ -89,7 +89,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
TypeId type_id = kNumberTypeFloat32;
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
std::vector<int> temp_shape;
if (graph.IsInternalOutput(node)) {
if (graph.IsInternalOutput(node, output_index)) {
temp_shape.emplace_back(1);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_device_address(address);
@ -307,8 +307,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
auto ref_real_node = real_kernel.first;
auto ref_real_node_index = real_kernel.second;
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
node_graph->IsFinalOutputKernel(ref_real_node)) {
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node, ref_real_node_index)) {
auto kernel_info = ref_real_node->kernel_info();
if (kernel_info == nullptr || !kernel_info->has_build_info()) {
MS_LOG(INFO) << "No kernel info";

View File

@ -492,6 +492,10 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown &&
AnfAlgo::OutputAddrExist(real_input_node, 0)) {
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);

View File

@ -355,7 +355,6 @@ void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) {
if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
continue;
}
graph->AddFinalOutputKernel(item_with_index.first);
if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
} else {

View File

@ -122,7 +122,7 @@ TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_sing
auto make_tuple = GetMakeTuple(kg);
auto trans_data = make_tuple->cast<CNodePtr>()->input(1);
EXPECT_TRUE(kg->IsInternalOutput(trans_data));
EXPECT_TRUE(kg->IsInternalOutput(trans_data, 0));
// remove trans op for internal output
auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
@ -156,8 +156,8 @@ TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_mult
auto make_tuple1 = tuple_getitem->cast<CNodePtr>()->input(1);
auto trans_data1 = make_tuple1->cast<CNodePtr>()->input(1);
auto trans_data2 = make_tuple1->cast<CNodePtr>()->input(2);
EXPECT_TRUE(kg->IsInternalOutput(trans_data1));
EXPECT_TRUE(kg->IsInternalOutput(trans_data2));
EXPECT_TRUE(kg->IsInternalOutput(trans_data1, 0));
EXPECT_TRUE(kg->IsInternalOutput(trans_data2, 0));
// remove trans op for internal output
auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();