forked from mindspore-Ecosystem/mindspore
!9507 Fix bug in the scenario where one tensor is used for multiple graph outputs in the pynative bp graph
From: @HulkTang Reviewed-by: @chujinjin Signed-off-by: @chujinjin
commit 0d21d5b570
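The core of the change: `output_indexes` used to map each kernel output (`KernelWithIndex`) to a single index path into the nested graph-output structure, so a tensor that appears at several output positions kept only the last recorded path. The commit widens the value type to a list of index paths. A minimal sketch of that data-structure change, using hypothetical simplified stand-ins (the `KernelWithIndex` and path aliases below are not the real MindSpore definitions):

// Sketch only -- simplified stand-ins, not the actual MindSpore types.
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

using KernelWithIndex = std::pair<const void *, std::size_t>;  // (node, output slot)
using IndexPath = std::vector<std::size_t>;                     // path into the nested graph output

// Before: one path per kernel output; a later occurrence overwrote the earlier one.
std::map<KernelWithIndex, IndexPath> output_indexes_before;

// After: every position the same tensor occupies in the graph output is kept.
std::map<KernelWithIndex, std::vector<IndexPath>> output_indexes_after;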
@@ -128,7 +128,7 @@ void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
 BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                     const std::vector<tensor::TensorPtr> &input_tensors,
                                     const std::vector<size_t> &indexes,
-                                    std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
+                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
   auto &node = node_output_pair.first;
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(graph);
@@ -152,7 +152,7 @@ BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_
     }
     MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
   }
-  (*output_indexes)[node_output_pair] = indexes;
+  (*output_indexes)[node_output_pair].emplace_back(indexes);
   BaseRef output_placeholder = std::make_shared<BaseRef>();
   return output_placeholder;
 }
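With the wider value type, placeholder creation appends the current index path instead of assigning it, which is what the `emplace_back` in the hunk above does. A small hedged illustration of the behavioural difference (a plain `int` key stands in for `KernelWithIndex`):

#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

int main() {
  using IndexPath = std::vector<std::size_t>;
  std::map<int, std::vector<IndexPath>> output_indexes;  // key 0 stands in for one KernelWithIndex

  // The same kernel output is referenced by graph-output positions {0} and {2, 1}.
  output_indexes[0].emplace_back(IndexPath{0});
  output_indexes[0].emplace_back(IndexPath{2, 1});

  // Both paths are kept; the old assignment form would have retained only the last one.
  std::cout << "paths recorded: " << output_indexes[0].size() << '\n';  // prints 2
  return 0;
}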
@@ -160,7 +160,7 @@ BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_
 BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                     const std::vector<tensor::TensorPtr> &input_tensors,
                                     const std::vector<size_t> &indexes,
-                                    std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
+                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
   MS_EXCEPTION_IF_NULL(anf);
   MS_EXCEPTION_IF_NULL(output_indexes);
   MS_LOG(INFO) << "Create placeholder for output[" << anf->DebugString() << "]";
@@ -189,7 +189,8 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
 }
 
 void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
-                             VectorRef *outputs, std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
+                             VectorRef *outputs,
+                             std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   MS_EXCEPTION_IF_NULL(outputs);
   MS_EXCEPTION_IF_NULL(output_indexes);
@@ -333,7 +334,7 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern
 }
 
 void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
-                     const std::map<KernelWithIndex, std::vector<size_t>> &output_indexes,
+                     const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
                      const std::map<KernelWithIndex, size_t> &ref_count,
                      std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
   auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
@@ -350,19 +351,25 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
     if (iter == output_indexes.end()) {
       continue;
     }
-    const std::vector<size_t> &ref_indexes = iter->second;
-    size_t n = 0;
-    const VectorRef *cur_vector_ref = outputs;
-    while (n != ref_indexes.size() - 1) {
-      size_t index = ref_indexes.at(n++);
-      const BaseRef &base_ref = (*cur_vector_ref)[index];
-      if (!utils::isa<VectorRef>(base_ref)) {
-        MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, indexes: " << ref_indexes << "cur n: " << n - 1;
+    const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
+    for (const auto &ref_indexes : multiple_ref_indexes) {
+      size_t n = 0;
+      const VectorRef *cur_vector_ref = outputs;
+      for (; n < ref_indexes.size() - 1; n += 1) {
+        size_t index = ref_indexes.at(n);
+        if (index >= cur_vector_ref->size()) {
+          MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
+                            << cur_vector_ref->size();
+        }
+        const BaseRef &base_ref = (*cur_vector_ref)[index];
+        if (!utils::isa<VectorRef>(base_ref)) {
+          MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
+        }
+        cur_vector_ref = &utils::cast<VectorRef>(base_ref);
       }
-      cur_vector_ref = &utils::cast<VectorRef>(base_ref);
+      BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
+      tensor_ref = output_tensor;
     }
-    BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
-    tensor_ref = output_tensor;
   }
 }
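The rewritten body of `HandleOpOutputs` walks every recorded index path through the nested output container and writes the op's output tensor at each position, with bounds and type checks along the way. Below is a self-contained sketch of the same traversal over a hypothetical nested structure (a stand-in for `VectorRef`/`BaseRef`, not the MindSpore classes); paths are assumed non-empty, as in the diff:

#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for BaseRef/VectorRef: a node is either a leaf value or a nested list.
struct Node {
  int value = 0;                                // leaf payload (the tensor in the real code)
  std::vector<std::shared_ptr<Node>> children;  // non-empty => behaves like a VectorRef
};

// Follow one index path; every index except the last must land on a nested list.
void WriteAtPath(Node *root, const std::vector<std::size_t> &path, int tensor) {
  Node *cur = root;
  for (std::size_t n = 0; n + 1 < path.size(); ++n) {
    // at() throws on an out-of-range index, mirroring the MS_LOG(EXCEPTION) checks above.
    cur = cur->children.at(path[n]).get();
  }
  cur->children.at(path.back())->value = tensor;
}

// One tensor, several graph-output positions: write it once per recorded path.
void WriteAtAllPaths(Node *root, const std::vector<std::vector<std::size_t>> &paths, int tensor) {
  for (const auto &path : paths) {
    WriteAtPath(root, path, tensor);
  }
}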
@@ -725,7 +732,7 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
   auto kernel_graph = GetGraph(graph_id);
   std::map<AnfNodePtr, size_t> parameter_index;
   GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
-  std::map<KernelWithIndex, std::vector<size_t>> output_indexes;
+  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
   CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
   std::map<KernelWithIndex, size_t> cnode_ref;
   GetRefCount(kernel_graph.get(), &cnode_ref);
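Taken together, as the hunks above suggest: `RunOpsInGraphImpl` now declares the nested `std::map<KernelWithIndex, std::vector<std::vector<size_t>>>`, `CreateOutputPlaceholder`/`CreateNodeOutputPlaceholder` record every index path while the output placeholders are built, and `HandleOpOutputs` later resolves each of those paths so the single op output tensor lands in all of its positions in the graph output; the surrounding call sites are not shown in this diff.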