forked from OSSInnovation/mindspore
!3333 Fix internal multiple outputs check
Merge pull request !3333 from YuJianfeng/master
This commit is contained in:
commit
24dd2fc2a6
|
@ -153,7 +153,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|||
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
|
||||
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
||||
auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
|
||||
kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
|
||||
}
|
||||
make_tuple_inputs.emplace_back(trans_op);
|
||||
|
@ -265,7 +265,7 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
|
|||
// Single output
|
||||
if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
|
||||
auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, 0)) {
|
||||
kernel_graph->ReplaceInternalOutput(node, new_node);
|
||||
}
|
||||
return new_node;
|
||||
|
|
|
@ -65,7 +65,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
|
|||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
|
||||
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
|
||||
}
|
||||
} else {
|
||||
|
@ -114,7 +114,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
|
|||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, 0)) {
|
||||
kernel_graph->ReplaceInternalOutput(cnode, replace_node);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c
|
|||
if (kernel_graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!kernel_graph->IsInternalOutput(node)) {
|
||||
if (!kernel_graph->IsInternalOutput(node, 0)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!UsedForOutputOnly(func_graph, node)) {
|
||||
|
|
|
@ -1021,26 +1021,16 @@ AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_nod
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Returns true when `node` appears as a key in internal_outputs_to_front_map_,
// false otherwise. This form does not look at any particular output index.
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
|
||||
if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Records `node` in final_output_kernels_. A null `node` is silently ignored.
void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return;
|
||||
}
|
||||
// The insert result is deliberately discarded.
(void)final_output_kernels_.insert(node);
|
||||
}
|
||||
|
||||
bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (final_output_kernels_.find(node) != final_output_kernels_.end()) {
|
||||
return true;
|
||||
// Reports whether `node` is recorded in internal_outputs_to_front_map_.
//   node:       key looked up in internal_outputs_to_front_map_.
//   output_idx: -1 means "any output of node"; any other value must also be
//               present as a key in the per-node inner map.
// Returns true on a match, false when `node` is absent or the requested
// output index is not registered for it.
bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, int output_idx) const {
|
||||
auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
|
||||
if (front_nodes_iter != internal_outputs_to_front_map_.end()) {
|
||||
// -1 acts as a wildcard: the mere presence of `node` is enough.
if (output_idx == -1) {
|
||||
return true;
|
||||
}
|
||||
// Otherwise the specific output index must have its own entry.
auto &front_nodes = front_nodes_iter->second;
|
||||
if (front_nodes.find(output_idx) != front_nodes.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -153,9 +153,7 @@ class KernelGraph : public FuncGraph {
|
|||
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
|
||||
int dst_output_idx = -1);
|
||||
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
|
||||
bool IsInternalOutput(const AnfNodePtr &node) const;
|
||||
void AddFinalOutputKernel(const AnfNodePtr &node);
|
||||
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
|
||||
bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
|
||||
uint32_t current_epoch() const { return current_epoch_; }
|
||||
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
|
||||
void UpdateChildGraphOrder();
|
||||
|
@ -230,7 +228,6 @@ class KernelGraph : public FuncGraph {
|
|||
bool null_output_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
|
||||
std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
|
||||
std::set<AnfNodePtr> final_output_kernels_;
|
||||
uint32_t current_epoch_;
|
||||
};
|
||||
} // namespace session
|
||||
|
|
|
@ -89,7 +89,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
TypeId type_id = kNumberTypeFloat32;
|
||||
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
|
||||
std::vector<int> temp_shape;
|
||||
if (graph.IsInternalOutput(node)) {
|
||||
if (graph.IsInternalOutput(node, output_index)) {
|
||||
temp_shape.emplace_back(1);
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
|
||||
tensor->set_device_address(address);
|
||||
|
@ -307,8 +307,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
|
|||
auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
|
||||
auto ref_real_node = real_kernel.first;
|
||||
auto ref_real_node_index = real_kernel.second;
|
||||
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
|
||||
node_graph->IsFinalOutputKernel(ref_real_node)) {
|
||||
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node, ref_real_node_index)) {
|
||||
auto kernel_info = ref_real_node->kernel_info();
|
||||
if (kernel_info == nullptr || !kernel_info->has_build_info()) {
|
||||
MS_LOG(INFO) << "No kernel info";
|
||||
|
|
|
@ -492,6 +492,10 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|||
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown &&
|
||||
AnfAlgo::OutputAddrExist(real_input_node, 0)) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
||||
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
||||
builder->SetOutputsFormat(output_format);
|
||||
|
|
|
@ -355,7 +355,6 @@ void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) {
|
|||
if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
|
||||
continue;
|
||||
}
|
||||
graph->AddFinalOutputKernel(item_with_index.first);
|
||||
if (AnfAlgo::IsCommunicationOp(item_with_index.first)) {
|
||||
AssignCommunicationNodeMem(kStaticMem, item_with_index.first);
|
||||
} else {
|
||||
|
|
|
@ -122,7 +122,7 @@ TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_sing
|
|||
|
||||
auto make_tuple = GetMakeTuple(kg);
|
||||
auto trans_data = make_tuple->cast<CNodePtr>()->input(1);
|
||||
EXPECT_TRUE(kg->IsInternalOutput(trans_data));
|
||||
EXPECT_TRUE(kg->IsInternalOutput(trans_data, 0));
|
||||
|
||||
// remove trans op for internal output
|
||||
auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
|
||||
|
@ -156,8 +156,8 @@ TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_mult
|
|||
auto make_tuple1 = tuple_getitem->cast<CNodePtr>()->input(1);
|
||||
auto trans_data1 = make_tuple1->cast<CNodePtr>()->input(1);
|
||||
auto trans_data2 = make_tuple1->cast<CNodePtr>()->input(2);
|
||||
EXPECT_TRUE(kg->IsInternalOutput(trans_data1));
|
||||
EXPECT_TRUE(kg->IsInternalOutput(trans_data2));
|
||||
EXPECT_TRUE(kg->IsInternalOutput(trans_data1, 0));
|
||||
EXPECT_TRUE(kg->IsInternalOutput(trans_data2, 0));
|
||||
|
||||
// remove trans op for internal output
|
||||
auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
|
||||
|
|
Loading…
Reference in New Issue