!9507 Fix bug in scenario that one tensor for multiple graph output in pynative bp graph

From: @HulkTang
Reviewed-by: @chujinjin
Signed-off-by: @chujinjin
This commit is contained in:
mindspore-ci-bot 2020-12-07 10:04:43 +08:00 committed by Gitee
commit 0d21d5b570
1 changed files with 24 additions and 17 deletions

View File

@ -128,7 +128,7 @@ void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<size_t> &indexes,
std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
auto &node = node_output_pair.first;
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
@ -152,7 +152,7 @@ BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_
}
MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
}
(*output_indexes)[node_output_pair] = indexes;
(*output_indexes)[node_output_pair].emplace_back(indexes);
BaseRef output_placeholder = std::make_shared<BaseRef>();
return output_placeholder;
}
@ -160,7 +160,7 @@ BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_
BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<size_t> &indexes,
std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(output_indexes);
MS_LOG(INFO) << "Create placeholder for output[" << anf->DebugString() << "]";
@ -189,7 +189,8 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
}
void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs, std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
VectorRef *outputs,
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(output_indexes);
@ -333,7 +334,7 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern
}
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<size_t>> &output_indexes,
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
@ -350,19 +351,25 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
if (iter == output_indexes.end()) {
continue;
}
const std::vector<size_t> &ref_indexes = iter->second;
size_t n = 0;
const VectorRef *cur_vector_ref = outputs;
while (n != ref_indexes.size() - 1) {
size_t index = ref_indexes.at(n++);
const BaseRef &base_ref = (*cur_vector_ref)[index];
if (!utils::isa<VectorRef>(base_ref)) {
MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, indexes: " << ref_indexes << "cur n: " << n - 1;
const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
for (const auto &ref_indexes : multiple_ref_indexes) {
size_t n = 0;
const VectorRef *cur_vector_ref = outputs;
for (; n < ref_indexes.size() - 1; n += 1) {
size_t index = ref_indexes.at(n);
if (index >= cur_vector_ref->size()) {
MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
<< cur_vector_ref->size();
}
const BaseRef &base_ref = (*cur_vector_ref)[index];
if (!utils::isa<VectorRef>(base_ref)) {
MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
}
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
}
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
}
}
@ -725,7 +732,7 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
auto kernel_graph = GetGraph(graph_id);
std::map<AnfNodePtr, size_t> parameter_index;
GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
std::map<KernelWithIndex, std::vector<size_t>> output_indexes;
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
std::map<KernelWithIndex, size_t> cnode_ref;
GetRefCount(kernel_graph.get(), &cnode_ref);