forked from mindspore-Ecosystem/mindspore
!27883 Fix getitem in getitem.
Merge pull request !27883 from gaoyong10/runtime_second12
This commit is contained in:
commit
f8e820a3be
|
@ -225,13 +225,6 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch outputs by control nodes.
|
|
||||||
if (AnfAlgo::IsCallNode(node)) {
|
|
||||||
const auto &control_node_output = AnfAlgo::GetAllOutputByCallNode(output_with_index);
|
|
||||||
(void)std::copy(control_node_output.begin(), control_node_output.end(), std::back_inserter(ret));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The InitDataSetQueue node has no output.
|
// The InitDataSetQueue node has no output.
|
||||||
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimInitDataSetQueue)) {
|
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimInitDataSetQueue)) {
|
||||||
return ret_empty;
|
return ret_empty;
|
||||||
|
@ -398,6 +391,7 @@ size_t AnfRuntimeAlgorithm::GetOutputNumByAbstract(const AbstractBasePtr &node_a
|
||||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||||
const auto &sub_abstracts = tuple_abstract->elements();
|
const auto &sub_abstracts = tuple_abstract->elements();
|
||||||
for (const auto &sub_abstract : sub_abstracts) {
|
for (const auto &sub_abstract : sub_abstracts) {
|
||||||
|
MS_EXCEPTION_IF_NULL(sub_abstract);
|
||||||
result += GetOutputNumByAbstract(sub_abstract);
|
result += GetOutputNumByAbstract(sub_abstract);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
@ -414,7 +408,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputByCallNode(const K
|
||||||
auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
|
auto tuple_abstract = node_abstract->cast<abstract::AbstractTuplePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||||
const auto &sub_abstracts = tuple_abstract->elements();
|
const auto &sub_abstracts = tuple_abstract->elements();
|
||||||
if (sub_abstracts.size() <= output_with_index.second) {
|
if (GetOutputNumByAbstract(tuple_abstract) <= output_with_index.second) {
|
||||||
MS_LOG(EXCEPTION) << "Invalid index:" << output_with_index.second
|
MS_LOG(EXCEPTION) << "Invalid index:" << output_with_index.second
|
||||||
<< "for node:" << output_with_index.first->DebugString();
|
<< "for node:" << output_with_index.first->DebugString();
|
||||||
}
|
}
|
||||||
|
@ -423,9 +417,11 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputByCallNode(const K
|
||||||
// to count the number of outputs before the target in order to accurately obtain its number.
|
// to count the number of outputs before the target in order to accurately obtain its number.
|
||||||
size_t pre_output_num = 0;
|
size_t pre_output_num = 0;
|
||||||
for (size_t i = 0; i < output_with_index.second; ++i) {
|
for (size_t i = 0; i < output_with_index.second; ++i) {
|
||||||
|
MS_EXCEPTION_IF_NULL(sub_abstracts[i]);
|
||||||
pre_output_num += GetOutputNumByAbstract(sub_abstracts[i]);
|
pre_output_num += GetOutputNumByAbstract(sub_abstracts[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(sub_abstracts[output_with_index.second]);
|
||||||
size_t output_num = GetOutputNumByAbstract(sub_abstracts[output_with_index.second]);
|
size_t output_num = GetOutputNumByAbstract(sub_abstracts[output_with_index.second]);
|
||||||
std::vector<KernelWithIndex> results;
|
std::vector<KernelWithIndex> results;
|
||||||
for (size_t i = 0; i < output_num; ++i) {
|
for (size_t i = 0; i < output_num; ++i) {
|
||||||
|
|
|
@ -43,6 +43,9 @@ void ExitActor::FetchInput(OpContext<DeviceTensor> *const context) {
|
||||||
if (data_iter != output_branch_data_.end()) {
|
if (data_iter != output_branch_data_.end()) {
|
||||||
for (auto &output_data : data_iter->second) {
|
for (auto &output_data : data_iter->second) {
|
||||||
MS_EXCEPTION_IF_NULL(output_data.second);
|
MS_EXCEPTION_IF_NULL(output_data.second);
|
||||||
|
if (output_data.first >= input_device_tensors_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid from index:" << output_data.first << " for actor:" << GetAID();
|
||||||
|
}
|
||||||
MS_EXCEPTION_IF_NULL(input_device_tensors_[output_data.first]);
|
MS_EXCEPTION_IF_NULL(input_device_tensors_[output_data.first]);
|
||||||
output_data.second->data_ = input_device_tensors_[output_data.first];
|
output_data.second->data_ = input_device_tensors_[output_data.first];
|
||||||
}
|
}
|
||||||
|
|
|
@ -528,6 +528,11 @@ void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::set<FuncGrap
|
||||||
|
|
||||||
// Fetch all inputs of node.
|
// Fetch all inputs of node.
|
||||||
std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (HasAbstractMonad(node)) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
// The node is divided into the following types:
|
// The node is divided into the following types:
|
||||||
// 1. depend and load.
|
// 1. depend and load.
|
||||||
const auto &node_with_index =
|
const auto &node_with_index =
|
||||||
|
@ -586,18 +591,43 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
||||||
|
|
||||||
// 4. One output node.
|
// 4. One output node.
|
||||||
const auto &abstract = real_node->abstract();
|
const auto &abstract = real_node->abstract();
|
||||||
if (abstract == nullptr ||
|
if (abstract == nullptr) {
|
||||||
((!abstract->isa<abstract::AbstractTuple>()) && (!abstract->isa<abstract::AbstractCSRTensor>()))) {
|
MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString();
|
||||||
if (abstract == nullptr) {
|
results.emplace_back(AnfAlgo::VisitKernelWithReturnType(real_node, real_index));
|
||||||
MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString();
|
return results;
|
||||||
}
|
|
||||||
return {AnfAlgo::VisitKernelWithReturnType(real_node, real_index)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Abstract is Tuple.
|
// 5 Other.
|
||||||
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||||
for (size_t i = 0; i < output_num; ++i) {
|
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) {
|
||||||
results.emplace_back(real_node, i);
|
const auto &get_item_cnode = real_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(get_item_cnode);
|
||||||
|
const auto &get_item_src_node = AnfAlgo::GetTupleGetItemRealInput(get_item_cnode);
|
||||||
|
size_t get_item_src_index = AnfAlgo::GetTupleGetItemOutIndex(get_item_cnode);
|
||||||
|
|
||||||
|
// Input node of getitm is a make tuple.
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(get_item_src_node, prim::kPrimMakeTuple)) {
|
||||||
|
const auto &make_tuple_cnode = get_item_src_node->cast<CNodePtr>();
|
||||||
|
const auto &makt_tuple_inputs = make_tuple_cnode->inputs();
|
||||||
|
if (makt_tuple_inputs.size() <= get_item_src_index) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid index:" << get_item_src_index
|
||||||
|
<< " for make tuple node : " << get_item_src_node->DebugString();
|
||||||
|
}
|
||||||
|
const auto &sub_results = FetchInputNodeByNode(makt_tuple_inputs[get_item_src_index + kMakeTupleInputStartPos]);
|
||||||
|
results.insert(results.end(), sub_results.begin(), sub_results.end());
|
||||||
|
} else {
|
||||||
|
// Input node of getitm is a parameter or make tuple.
|
||||||
|
auto get_item_src_abstract = get_item_src_node->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(get_item_src_abstract);
|
||||||
|
auto real_indexs = FetchRealIndexByAbstract(get_item_src_abstract, get_item_src_index);
|
||||||
|
(void)std::transform(
|
||||||
|
real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
|
||||||
|
[&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < output_num; ++i) {
|
||||||
|
results.emplace_back(real_node, i);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
@ -1705,7 +1735,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
|
||||||
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
|
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
|
||||||
// If input come from the output of kernel graph belong the same group, it should not be collected in
|
// If input come from the output of kernel graph belong the same group, it should not be collected in
|
||||||
// the group inputs.
|
// the group inputs.
|
||||||
if (HasAbstractMonad(front_node_with_index.first) ||
|
if (HasAbstractMonad(front_node_with_index.first) || HasAbstractMonad(parameter) ||
|
||||||
kernel_graph_group_info->front_output_nodes_.find(front_node_with_index) !=
|
kernel_graph_group_info->front_output_nodes_.find(front_node_with_index) !=
|
||||||
kernel_graph_group_info->front_output_nodes_.end()) {
|
kernel_graph_group_info->front_output_nodes_.end()) {
|
||||||
continue;
|
continue;
|
||||||
|
|
Loading…
Reference in New Issue