!16963 Code clean
From: @HulkTang Reviewed-by: @kisnwang,@jjfeing Signed-off-by: @jjfeing
This commit is contained in:
commit
7bf41f18ed
|
@ -172,7 +172,6 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
|
|||
tensor::TensorPtr stub_output_tensor =
|
||||
std::make_shared<tensor::Tensor>(infer_type, tensor_abstract->shape()->shape(), nullptr);
|
||||
const auto &output_type = AnfAlgo::GetOutputDeviceDataType(output_node, output_index);
|
||||
const auto &output_shape = AnfAlgo::GetOutputDeviceShape(output_node, output_index);
|
||||
const auto &output_format = AnfAlgo::GetOutputFormat(output_node, output_index);
|
||||
tensor::DeviceInfo device_info;
|
||||
device_info.format_ = output_format;
|
||||
|
@ -707,7 +706,7 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
|
|||
return;
|
||||
}
|
||||
|
||||
const auto &graph = PreBuildOp(op_run_info, graph_info, input_tensors, tensors_mask);
|
||||
const auto &graph = PreBuildOp(op_run_info, input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// init runtime resource
|
||||
InitRuntimeResource();
|
||||
|
@ -758,7 +757,7 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
|
|||
MS_LOG(INFO) << "Run op " << op_run_info->op_name << " finish!";
|
||||
}
|
||||
|
||||
KernelGraphPtr AscendSession::PreBuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
KernelGraphPtr AscendSession::PreBuildOp(const OpRunInfo &op_run_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
// Construct graph include one op
|
||||
|
@ -816,7 +815,7 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::map<KernelWithIndex, OutputTensorInfo> op_output_info;
|
||||
std::vector<CNodePtr> kernels;
|
||||
std::unordered_map<KernelGraphPtr, std::vector<GraphInfo>> single_op_graphs;
|
||||
std::unordered_map<KernelGraphPtr, GraphInfo> single_op_graphs;
|
||||
// Collect kernels need to be built in single op graphs
|
||||
for (const auto &kernel : graph->execution_order()) {
|
||||
// Generate fake input tensors, tensor masks and input kernel with index
|
||||
|
@ -837,13 +836,13 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
|
|||
continue;
|
||||
}
|
||||
const auto &single_op_graph =
|
||||
PreBuildOp(op_run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
|
||||
PreBuildOp(op_run_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(single_op_graph);
|
||||
GenOpOutputStubTensor(single_op_graph, kernel, cnode_refcount, &op_output_info);
|
||||
opt::HideNopNode(single_op_graph.get());
|
||||
// The graph info could have been changed in PreBuildOp
|
||||
const GraphInfo &new_graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
|
||||
single_op_graphs.insert({single_op_graph, {graph_info, new_graph_info}});
|
||||
single_op_graphs.emplace(single_op_graph, new_graph_info);
|
||||
const auto &execution_order = single_op_graph->execution_order();
|
||||
std::copy(execution_order.begin(), execution_order.end(), std::back_inserter(kernels));
|
||||
}
|
||||
|
@ -861,10 +860,8 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
|
|||
// Record single op graphs in run_op_graphs_ so that these graphs can be reused in BuildOpImpl
|
||||
for (const auto &graph_item : single_op_graphs) {
|
||||
RunOpMemoryClear(graph_item.first.get());
|
||||
for (const auto &graph_info : graph_item.second) {
|
||||
run_op_graphs_[graph_info] = graph_item.first;
|
||||
MS_LOG(DEBUG) << "Pre build op finished, graph info: " << graph_info;
|
||||
}
|
||||
run_op_graphs_[graph_item.second] = graph_item.first;
|
||||
MS_LOG(DEBUG) << "Pre build op finished, graph info: " << graph_item.second;
|
||||
}
|
||||
built_graph_id_.insert(graph_id);
|
||||
}
|
||||
|
|
|
@ -117,8 +117,7 @@ class AscendSession : public SessionBasic {
|
|||
void LoadGraphsToDbg(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
|
||||
KernelGraphPtr PreBuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
KernelGraphPtr PreBuildOp(const OpRunInfo &op_run_info, const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask);
|
||||
void GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs,
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <queue>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ir/manager.h"
|
||||
|
@ -481,14 +482,12 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern
|
|||
}
|
||||
|
||||
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
|
||||
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
|
||||
const std::map<KernelWithIndex, size_t> &ref_count,
|
||||
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs,
|
||||
std::vector<TensorPtr> *runop_output_tensors) {
|
||||
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, GraphOutputInfo *graph_output_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(op_output_map);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(runop_output_tensors);
|
||||
MS_EXCEPTION_IF_NULL(graph_output_info);
|
||||
MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
|
||||
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
|
||||
if (output_tensors.size() > op_outputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
|
||||
|
@ -499,14 +498,14 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
|
|||
if (ref_count.find(kernel_with_index) != ref_count.end()) {
|
||||
(*op_output_map)[kernel_with_index] = output_tensor;
|
||||
}
|
||||
const auto &iter = output_indexes.find(kernel_with_index);
|
||||
if (iter == output_indexes.end()) {
|
||||
const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
|
||||
if (iter == graph_output_info->output_indexes.end()) {
|
||||
continue;
|
||||
}
|
||||
const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
|
||||
for (const auto &ref_indexes : multiple_ref_indexes) {
|
||||
size_t n = 0;
|
||||
const VectorRef *cur_vector_ref = outputs;
|
||||
const VectorRef *cur_vector_ref = graph_output_info->graph_outputs;
|
||||
for (; n < ref_indexes.size() - 1; n += 1) {
|
||||
size_t index = ref_indexes.at(n);
|
||||
if (index >= cur_vector_ref->size()) {
|
||||
|
@ -521,7 +520,7 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
|
|||
}
|
||||
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
|
||||
tensor_ref = output_tensor;
|
||||
runop_output_tensors->emplace_back(output_tensor);
|
||||
graph_output_info->graph_output_tensors.emplace_back(output_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2136,8 +2135,9 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
std::map<AnfNodePtr, size_t> parameter_index;
|
||||
GetParameterIndex(kernel_graph.get(), inputs, ¶meter_index);
|
||||
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
|
||||
CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
|
||||
GraphOutputInfo graph_output_info;
|
||||
graph_output_info.graph_outputs = outputs;
|
||||
CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes);
|
||||
std::map<KernelWithIndex, size_t> cnode_refcount;
|
||||
GetRefCount(kernel_graph.get(), &cnode_refcount);
|
||||
BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
|
||||
|
@ -2163,14 +2163,12 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
|
|||
RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
|
||||
input_tensor_info.input_tensors_mask);
|
||||
|
||||
std::vector<tensor::TensorPtr> new_output_tensors;
|
||||
|
||||
// Handle inputs and outputs of current op
|
||||
HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
|
||||
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs, &new_output_tensors);
|
||||
HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
|
||||
// Save grad node to Bucket
|
||||
if (kernel_graph->is_bprop()) {
|
||||
AddGradAddrToBucket(graph_id, new_output_tensors);
|
||||
AddGradAddrToBucket(graph_id, graph_output_info.graph_output_tensors);
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
|
|
|
@ -78,6 +78,12 @@ struct OutputTensorInfo {
|
|||
bool is_weight;
|
||||
};
|
||||
|
||||
struct GraphOutputInfo {
|
||||
VectorRef *graph_outputs;
|
||||
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
|
||||
std::vector<tensor::TensorPtr> graph_output_tensors;
|
||||
};
|
||||
|
||||
using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
|
||||
using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;
|
||||
class Executor;
|
||||
|
|
Loading…
Reference in New Issue