!46575 Fetch call node output type by abstract.

Merge pull request !46575 from gaoyong10/dynamic_shape_01
This commit is contained in:
i-robot 2023-03-07 11:32:36 +00:00 committed by Gitee
commit 2cc4d8044f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 52 additions and 34 deletions

View File

@ -298,6 +298,8 @@ class COMMON_EXPORT AnfAlgo {
// Get the element shape of dynamic sequence shape.
static abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx);
// Fetch the sub abstract from the top abstract by the index.
static abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
};
} // namespace common
} // namespace mindspore

View File

@ -50,7 +50,7 @@ void StackActor::Init() {
MS_EXCEPTION_IF_NULL(formal_parameter.first);
const auto &abstract = formal_parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);
if (real_abstract->isa<abstract::AbstractFunction>()) {
total_partials_num++;

View File

@ -15,6 +15,7 @@
*/
#include <unordered_map>
#include <functional>
#include <map>
#include "runtime/graph_scheduler/control_node_parser.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
@ -319,7 +320,7 @@ TypeId FetchTypeIdByNode(const AnfNodePtr &node, size_t index) {
TypeId type_id = kTypeUnknown;
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
// For valuenode, fetch type from abstract.
const auto &abs = FetchAbstractByIndex(node->abstract(), index);
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
MS_EXCEPTION_IF_NULL(abs);
const auto &type = abs->BuildType();
MS_EXCEPTION_IF_NULL(type);
@ -341,7 +342,7 @@ size_t FetchOutputSizeByNode(const AnfNodePtr &node, size_t index, TypeId type_i
MS_EXCEPTION_IF_NULL(node);
size_t size = GetTypeByte(TypeIdToType(type_id));
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
const auto &abs = FetchAbstractByIndex(node->abstract(), index);
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
MS_EXCEPTION_IF_NULL(abs);
const auto &shape_ptr = abs->BuildShape();
MS_EXCEPTION_IF_NULL(shape_ptr);
@ -738,30 +739,6 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
return results;
}
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
MS_EXCEPTION_IF_NULL(abstract);
if (!abstract->isa<abstract::AbstractSequence>() || abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
if (index != 0) {
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
}
return abstract;
}
auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
size_t real_index = index;
for (const auto &sub_abstract : sub_abstracts) {
size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
if (real_index >= tmp_index) {
real_index -= tmp_index;
continue;
}
return FetchAbstractByIndex(sub_abstract, real_index);
}
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
}
bool IsPartialInput(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &abstract = node->abstract();
@ -2262,7 +2239,7 @@ void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *c
// Skip the function input.
const auto &abstract = backend_to_front.second.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, backend_to_front.second.second);
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, backend_to_front.second.second);
MS_EXCEPTION_IF_NULL(real_abstract);
if (real_abstract->isa<abstract::AbstractFunction>()) {
continue;

View File

@ -115,8 +115,6 @@ bool IsCooNode(const AnfNodePtr &node);
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph);
// Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem.
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
// Fetch the sub abstract from the top abstract by the index.
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
// Fetch the real input of tuple get item node.
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
// Check if the partial node is valid.

View File

@ -535,7 +535,7 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
const auto &abstract = parameter.first->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, parameter.second);
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);
if (real_abstract->isa<abstract::AbstractFunction>()) {
input_parameter_partials_num++;
@ -772,7 +772,7 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
const auto &abstract = from_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
MS_EXCEPTION_IF_NULL(real_abstract);
// Link arrow according to abstract.
@ -912,7 +912,7 @@ void ControlNodeScheduler::LinkArrowByParameter(const AnfNodePtr &parameter, Con
auto abstract = parameter->abstract();
MS_EXCEPTION_IF_NULL(abstract);
auto dst_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second);
auto dst_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
MS_EXCEPTION_IF_NULL(dst_abstract);
if (dst_abstract->isa<abstract::AbstractFunction>()) {
SchedulerHelper::AddPartialArrow(from_actor, to_actor, from_actor->FetchNodePosition(from_node_with_index),
@ -937,7 +937,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
// Link arrow from exit actor to control actor.
const auto &abstract = call_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
const auto &real_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second);
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
MS_EXCEPTION_IF_NULL(real_abstract);
const auto &func_graphs = parser->FetchFuncGraphbyCallNode(from_node);

View File

@ -751,6 +751,23 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
TypeId AnfAlgo::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (IsCallNode(node)) {
if (node->abstract() == nullptr) {
MS_LOG(EXCEPTION) << "Empty abstract of call node:" << node->DebugString();
}
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), output_idx);
MS_EXCEPTION_IF_NULL(abs);
const auto &type = abs->BuildType();
MS_EXCEPTION_IF_NULL(type);
if (type->isa<TensorType>()) {
const auto &tensor_type = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
const auto &element = tensor_type->element();
return element->type_id();
} else {
return type->type_id();
}
}
return GetOutputInferDataType(node->Type(), output_idx);
}
@ -1949,5 +1966,29 @@ abstract::BaseShapePtr AnfAlgo::GetDynamicSequenceShape(const AnfNodePtr &node,
}
return element_abs->BuildShape();
}
abstract::AbstractBasePtr AnfAlgo::FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
MS_EXCEPTION_IF_NULL(abstract);
if (!abstract->isa<abstract::AbstractSequence>() || abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
if (index != 0) {
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
}
return abstract;
}
auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
const auto &sub_abstracts = tuple_abstract->elements();
size_t real_index = index;
for (const auto &sub_abstract : sub_abstracts) {
size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
if (real_index >= tmp_index) {
real_index -= tmp_index;
continue;
}
return FetchAbstractByIndex(sub_abstract, real_index);
}
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
}
} // namespace common
} // namespace mindspore