forked from mindspore-Ecosystem/mindspore
Ignore useless output when generate stub tensor in BuildOpsInGraph
This commit is contained in:
parent
8fc5962418
commit
3d098708b9
|
@ -156,6 +156,7 @@ TensorPtr GetCNodeOutputStubTensor(const KernelWithIndex &kernel_with_index,
|
|||
}
|
||||
|
||||
void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr &kernel,
|
||||
const std::map<KernelWithIndex, size_t> &cnode_refcount,
|
||||
std::map<KernelWithIndex, OutputTensorInfo> *op_output_info) {
|
||||
MS_EXCEPTION_IF_NULL(single_op_graph);
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
|
@ -163,6 +164,10 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
|
|||
OutputTensorInfo output_tensor_info;
|
||||
size_t out_idx = 0;
|
||||
for (const auto &output : single_op_graph->outputs()) {
|
||||
KernelWithIndex kernel_with_index = std::make_pair(kernel, out_idx++);
|
||||
if (cnode_refcount.find(kernel_with_index) == cnode_refcount.end()) {
|
||||
continue;
|
||||
}
|
||||
const auto &output_kernel_with_index = AnfAlgo::VisitKernel(output, 0);
|
||||
const auto &output_node = output_kernel_with_index.first;
|
||||
const auto &output_index = output_kernel_with_index.second;
|
||||
|
@ -187,7 +192,6 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
|
|||
device::DeviceAddressPtr device_address =
|
||||
std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, 0, output_format, output_type);
|
||||
stub_output_tensor->set_device_address(device_address);
|
||||
KernelWithIndex kernel_with_index = std::make_pair(kernel, out_idx++);
|
||||
output_tensor_info.output_stub_tensor = stub_output_tensor;
|
||||
output_tensor_info.is_weight = !dynamic_cast<device::KernelInfo *>(output_node->kernel_info())->is_feature_map();
|
||||
(*op_output_info)[kernel_with_index] = output_tensor_info;
|
||||
|
@ -700,7 +704,8 @@ void AscendSession::GetOpInputStubTensors(const CNodePtr &cnode, const std::map<
|
|||
}
|
||||
|
||||
void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs) {
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs,
|
||||
const std::map<KernelWithIndex, size_t> &cnode_refcount) {
|
||||
if (built_graph_id_.find(graph_id) != built_graph_id_.end()) {
|
||||
return;
|
||||
}
|
||||
|
@ -722,13 +727,13 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
|
|||
if (single_op_graph_iter != run_op_graphs_.end()) {
|
||||
// if graph of same single op exists, the output tensor of current op should be generated
|
||||
const auto &single_op_graph = single_op_graph_iter->second;
|
||||
GenOpOutputStubTensor(single_op_graph, kernel, &op_output_info);
|
||||
GenOpOutputStubTensor(single_op_graph, kernel, cnode_refcount, &op_output_info);
|
||||
continue;
|
||||
}
|
||||
const auto &single_op_graph =
|
||||
PreBuildOp(op_run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(single_op_graph);
|
||||
GenOpOutputStubTensor(single_op_graph, kernel, &op_output_info);
|
||||
GenOpOutputStubTensor(single_op_graph, kernel, cnode_refcount, &op_output_info);
|
||||
opt::HideNopNode(single_op_graph.get());
|
||||
// The graph info could have been changed in PreBuildOp
|
||||
const GraphInfo &new_graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
|
||||
|
|
|
@ -59,7 +59,8 @@ class AscendSession : public SessionBasic {
|
|||
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
|
||||
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
|
||||
void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs) override;
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs,
|
||||
const std::map<KernelWithIndex, size_t> &cnode_refcount) override;
|
||||
|
||||
private:
|
||||
// compile child graph when session have multiple child graphs
|
||||
|
|
|
@ -2156,9 +2156,9 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
|
|||
GetParameterIndex(kernel_graph.get(), inputs, ¶meter_index);
|
||||
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
|
||||
CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
|
||||
std::map<KernelWithIndex, size_t> cnode_ref;
|
||||
GetRefCount(kernel_graph.get(), &cnode_ref);
|
||||
BuildOpsInGraph(graph_id, parameter_index, inputs);
|
||||
std::map<KernelWithIndex, size_t> cnode_refcount;
|
||||
GetRefCount(kernel_graph.get(), &cnode_refcount);
|
||||
BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
|
||||
|
||||
std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
|
||||
for (const auto &kernel : kernel_graph->execution_order()) {
|
||||
|
@ -2177,8 +2177,8 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
|
|||
input_tensor_info.input_tensors_mask);
|
||||
|
||||
// Handle inputs and outputs of current op
|
||||
HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
|
||||
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_ref, &op_output_map, outputs);
|
||||
HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
|
||||
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs);
|
||||
}
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
|
|
@ -178,7 +178,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
const std::vector<int64_t> &tensors_mask) {}
|
||||
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs) {}
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs,
|
||||
const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
|
||||
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
|
||||
|
||||
virtual void SetSummaryNodes(KernelGraph *graph);
|
||||
|
|
Loading…
Reference in New Issue