forked from mindspore-Ecosystem/mindspore
!28387 Support getitem in getitem.
Merge pull request !28387 from gaoyong10/runtime_second12
This commit is contained in:
commit
a8ecbb4ef8
|
@ -206,7 +206,7 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
|||
}
|
||||
|
||||
// If the node is a call, the outputs num should get from the abstract.
|
||||
if (AnfAlgo::IsCallNode(node)) {
|
||||
if (AnfAlgo::IsCallNode(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
auto abstract = node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
outputs_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
|
@ -315,7 +315,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
|
|||
return KernelWithIndex(anf_node, index);
|
||||
}
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
return KernelWithIndex(anf_node, 0);
|
||||
return KernelWithIndex(anf_node, index);
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -2634,5 +2634,45 @@ int64_t AnfRuntimeAlgorithm::GetAttrGroups(const AnfNodePtr &node, size_t index)
|
|||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
AnfNodePtr AnfRuntimeAlgorithm::GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *index_stack) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(index_stack);
|
||||
|
||||
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_getitem = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
// Get cur index
|
||||
auto output_index_value_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(output_index_value_node);
|
||||
auto value_node = output_index_value_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto output_idx = LongToSize(GetValue<int64_t>(value_node->value()));
|
||||
index_stack->push_back(output_idx);
|
||||
auto real_input = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
|
||||
return GetTupleIndexes(real_input, index_stack);
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
// If make_tuple in make_tuple, visit may start with inner tuple_getitem.
|
||||
if (index_stack->empty()) {
|
||||
MS_LOG(WARNING) << "Visit make tuple: " << node->DebugString()
|
||||
<< ", but index are empty, visit should not start with inner tuple_getitem.";
|
||||
return nullptr;
|
||||
}
|
||||
auto make_tuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
auto output_idx = index_stack->back();
|
||||
index_stack->pop_back();
|
||||
return GetTupleIndexes(make_tuple->input(1 + output_idx), index_stack);
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
return GetTupleIndexes(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), index_stack);
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
return GetTupleIndexes(node->cast<CNodePtr>()->input(1), index_stack);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get real node:" << node->DebugString();
|
||||
return node;
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -364,6 +364,8 @@ class AnfRuntimeAlgorithm {
|
|||
}
|
||||
|
||||
static void UpdateGraphValidRefPair(const KernelGraphPtr &graph);
|
||||
// Get the real output node and indexes of get item, make tuple, depend, load.
|
||||
static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *index_stack);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -223,10 +223,22 @@ KernelWithIndex FetchRealInputNode(const KernelWithIndex &node_with_index) {
|
|||
}
|
||||
|
||||
// Fetch all the output index in the sub-abstract of abstract.
|
||||
std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, size_t index) {
|
||||
std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std::vector<size_t> *indexes) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
MS_EXCEPTION_IF_NULL(indexes);
|
||||
AbstractBasePtr dst_abstract = abstract;
|
||||
size_t pre_abstract_num = 0;
|
||||
std::set<size_t> output_indexs;
|
||||
if (indexes->empty()) {
|
||||
size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
output_indexs.emplace(i);
|
||||
}
|
||||
return output_indexs;
|
||||
}
|
||||
|
||||
size_t index = indexes->back();
|
||||
indexes->pop_back();
|
||||
|
||||
// Fetch the dest abstract by index, and the abstracts num before the dest abstract.
|
||||
if (abstract->isa<abstract::AbstractCSRTensor>()) {
|
||||
|
@ -272,12 +284,11 @@ std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, size_
|
|||
MS_EXCEPTION_IF_NULL(dst_abstract);
|
||||
|
||||
// Fetch real output index.
|
||||
size_t ouput_num = AnfAlgo::GetOutputNumByAbstract(dst_abstract);
|
||||
std::set<size_t> real_indexs;
|
||||
for (size_t i = pre_abstract_num; i < ouput_num + pre_abstract_num; ++i) {
|
||||
real_indexs.emplace(i);
|
||||
auto tmp_indexs = FetchRealIndexByAbstract(dst_abstract, indexes);
|
||||
for (auto tmp_index : tmp_indexs) {
|
||||
output_indexs.emplace(tmp_index + pre_abstract_num);
|
||||
}
|
||||
return real_indexs;
|
||||
return output_indexs;
|
||||
}
|
||||
|
||||
// Get all the real parameters corresponding to node.
|
||||
|
@ -603,7 +614,8 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
// Csr node from parameter or call node.
|
||||
auto abstract = src_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto real_indexs = FetchRealIndexByAbstract(abstract, iter->second);
|
||||
std::vector<size_t> index_stack{LongToSize(iter->second)};
|
||||
auto real_indexs = FetchRealIndexByAbstract(abstract, &index_stack);
|
||||
(void)std::transform(real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
|
||||
[&src_node](const auto &index) { return KernelWithIndex(src_node, index); });
|
||||
}
|
||||
|
@ -620,30 +632,19 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
|
||||
// 5 Other.
|
||||
if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimTupleGetItem)) {
|
||||
const auto &get_item_cnode = real_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(get_item_cnode);
|
||||
const auto &get_item_src_node = AnfAlgo::GetTupleGetItemRealInput(get_item_cnode);
|
||||
size_t get_item_src_index = AnfAlgo::GetTupleGetItemOutIndex(get_item_cnode);
|
||||
|
||||
// Input node of getitm is a make tuple.
|
||||
if (AnfAlgo::CheckPrimitiveType(get_item_src_node, prim::kPrimMakeTuple)) {
|
||||
const auto &make_tuple_cnode = get_item_src_node->cast<CNodePtr>();
|
||||
const auto &makt_tuple_inputs = make_tuple_cnode->inputs();
|
||||
if (makt_tuple_inputs.size() <= get_item_src_index) {
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << get_item_src_index
|
||||
<< " for make tuple node : " << get_item_src_node->DebugString();
|
||||
}
|
||||
const auto &sub_results = FetchInputNodeByNode(makt_tuple_inputs[get_item_src_index + kMakeTupleInputStartPos]);
|
||||
std::vector<size_t> index_stack;
|
||||
auto get_item_src_node = AnfAlgo::GetTupleIndexes(real_node, &index_stack);
|
||||
MS_EXCEPTION_IF_NULL(get_item_src_node);
|
||||
if (index_stack.empty()) {
|
||||
const auto &sub_results = FetchInputNodeByNode(get_item_src_node);
|
||||
results.insert(results.end(), sub_results.begin(), sub_results.end());
|
||||
} else {
|
||||
// Input node of getitm is a parameter or make tuple.
|
||||
auto get_item_src_abstract = get_item_src_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(get_item_src_abstract);
|
||||
auto real_indexs = FetchRealIndexByAbstract(get_item_src_abstract, get_item_src_index);
|
||||
(void)std::transform(
|
||||
real_indexs.begin(), real_indexs.end(), std::back_inserter(results),
|
||||
[&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
|
||||
return results;
|
||||
}
|
||||
auto get_item_src_abstract = get_item_src_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(get_item_src_abstract);
|
||||
auto indexes = FetchRealIndexByAbstract(get_item_src_abstract, &index_stack);
|
||||
(void)std::transform(indexes.begin(), indexes.end(), std::back_inserter(results),
|
||||
[&get_item_src_node](const auto &index) { return KernelWithIndex(get_item_src_node, index); });
|
||||
return results;
|
||||
}
|
||||
|
||||
|
@ -690,6 +691,24 @@ void AddFormalToRealParameter(const AnfNodePtr &formal_parameter, const AnfNodeP
|
|||
}
|
||||
} // namespace
|
||||
|
||||
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
|
||||
MS_EXCEPTION_IF_NULL(node_with_index.first);
|
||||
std::vector<size_t> index_stack{node_with_index.second};
|
||||
|
||||
const auto &get_item_src_node = AnfAlgo::GetTupleIndexes(node_with_index.first, &index_stack);
|
||||
const auto &get_item_src_abstract = get_item_src_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(get_item_src_abstract);
|
||||
auto indexes = FetchRealIndexByAbstract(get_item_src_abstract, &index_stack);
|
||||
if (indexes.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to find index for node:" << get_item_src_node;
|
||||
}
|
||||
if (indexes.size() > 1) {
|
||||
MS_LOG(WARNING) << "Output size:" << indexes.size() << " for node:" << get_item_src_node->DebugString()
|
||||
<< " more than 1";
|
||||
}
|
||||
return {get_item_src_node, *(indexes.begin())};
|
||||
}
|
||||
|
||||
bool HasAbstractRef(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
|
@ -1225,7 +1244,8 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(call_device_contexts[output_node.second]);
|
||||
return_device_contexts.emplace_back(call_device_contexts[output_node.second]);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimPartial)) {
|
||||
} else if (AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimPartial) ||
|
||||
AnfAlgo::CheckPrimitiveType(output_node.first, prim::kPrimSwitch)) {
|
||||
return_device_contexts.emplace_back(default_context);
|
||||
} else if (output_node.first->isa<CNode>()) {
|
||||
// If the output is a cnode, get the device context type by the kernel.
|
||||
|
@ -1820,7 +1840,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
|
|||
// Collect inputs in group.
|
||||
const auto &real_parameters = kernel_graph->input_nodes();
|
||||
for (const auto ¶meter : real_parameters) {
|
||||
const auto &front_node_with_index = GetFrontNodeByKernelGraph(parameter, kernel_graph.get());
|
||||
auto front_node_with_index = GetFrontNodeByKernelGraph(parameter, kernel_graph.get());
|
||||
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
|
||||
// If input come from the output of kernel graph belong the same group, it should not be collected in
|
||||
// the group inputs.
|
||||
|
@ -1832,6 +1852,12 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte
|
|||
if (AnfAlgo::IsCallNode(front_node_with_index.first)) {
|
||||
kernel_graph_group_info->is_call_input_ = true;
|
||||
}
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimTupleGetItem)) {
|
||||
MS_LOG(WARNING) << "Input node:" << front_node_with_index.first->DebugString()
|
||||
<< " for graph:" << kernel_graph->ToString() << " is a tuple get item";
|
||||
front_node_with_index = FetchRealNodeByGetItem(front_node_with_index);
|
||||
}
|
||||
kernel_graph_group_info->front_input_nodes_[front_node_with_index] = iter->second;
|
||||
}
|
||||
|
||||
|
|
|
@ -103,6 +103,8 @@ KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, Kernel
|
|||
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
|
||||
// Fetch the sub abstract from the top abstract by the index.
|
||||
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
|
||||
// Fetch the real input of tuple get item node.
|
||||
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
|
||||
// ControlNodeParser is used to parse control nodes, and get the edges between nodes.
|
||||
class ControlNodeParser {
|
||||
public:
|
||||
|
|
|
@ -1128,6 +1128,13 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap
|
|||
if (from_node_with_index.first == nullptr) {
|
||||
from_node_with_index = tuple_node_with_index;
|
||||
}
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(from_node_with_index.first, prim::kPrimTupleGetItem)) {
|
||||
MS_LOG(WARNING) << "Input node:" << from_node_with_index.first->DebugString()
|
||||
<< " for graph:" << graph->ToString() << " is a tuple get item";
|
||||
from_node_with_index = FetchRealNodeByGetItem(from_node_with_index);
|
||||
}
|
||||
|
||||
// If the formal parameter is a tuple type, the parameter of the kernel graph will not directly correspond
|
||||
// to the front parameter, but the node in the internal parameter.
|
||||
const auto &from_node = from_node_with_index.first;
|
||||
|
|
Loading…
Reference in New Issue