forked from mindspore-Ecosystem/mindspore
!46575 Fetch call node output type by abstract.
Merge pull request !46575 from gaoyong10/dynamic_shape_01
This commit is contained in:
commit
2cc4d8044f
|
@ -298,6 +298,8 @@ class COMMON_EXPORT AnfAlgo {
|
|||
|
||||
// Get the element shape of dynamic sequence shape.
|
||||
static abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx);
|
||||
// Fetch the sub abstract from the top abstract by the index.
|
||||
static abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -50,7 +50,7 @@ void StackActor::Init() {
|
|||
MS_EXCEPTION_IF_NULL(formal_parameter.first);
|
||||
const auto &abstract = formal_parameter.first->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
|
||||
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
|
||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
||||
total_partials_num++;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include "runtime/graph_scheduler/control_node_parser.h"
|
||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||
|
@ -319,7 +320,7 @@ TypeId FetchTypeIdByNode(const AnfNodePtr &node, size_t index) {
|
|||
TypeId type_id = kTypeUnknown;
|
||||
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
|
||||
// For valuenode, fetch type from abstract.
|
||||
const auto &abs = FetchAbstractByIndex(node->abstract(), index);
|
||||
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
const auto &type = abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
@ -341,7 +342,7 @@ size_t FetchOutputSizeByNode(const AnfNodePtr &node, size_t index, TypeId type_i
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
size_t size = GetTypeByte(TypeIdToType(type_id));
|
||||
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
|
||||
const auto &abs = FetchAbstractByIndex(node->abstract(), index);
|
||||
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
const auto &shape_ptr = abs->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
|
@ -738,30 +739,6 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
|
|||
return results;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (!abstract->isa<abstract::AbstractSequence>() || abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
|
||||
if (index != 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
return abstract;
|
||||
}
|
||||
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
size_t real_index = index;
|
||||
for (const auto &sub_abstract : sub_abstracts) {
|
||||
size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
|
||||
if (real_index >= tmp_index) {
|
||||
real_index -= tmp_index;
|
||||
continue;
|
||||
}
|
||||
return FetchAbstractByIndex(sub_abstract, real_index);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
|
||||
bool IsPartialInput(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &abstract = node->abstract();
|
||||
|
@ -2262,7 +2239,7 @@ void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *c
|
|||
// Skip the function input.
|
||||
const auto &abstract = backend_to_front.second.first->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
const auto &real_abstract = FetchAbstractByIndex(abstract, backend_to_front.second.second);
|
||||
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, backend_to_front.second.second);
|
||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
||||
continue;
|
||||
|
|
|
@ -115,8 +115,6 @@ bool IsCooNode(const AnfNodePtr &node);
|
|||
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph);
|
||||
// Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem.
|
||||
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
|
||||
// Fetch the sub abstract from the top abstract by the index.
|
||||
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
|
||||
// Fetch the real input of tuple get item node.
|
||||
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
|
||||
// Check if the partial node is valid.
|
||||
|
|
|
@ -535,7 +535,7 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
|
|||
|
||||
const auto &abstract = parameter.first->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
const auto &real_abstract = FetchAbstractByIndex(abstract, parameter.second);
|
||||
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, parameter.second);
|
||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
||||
input_parameter_partials_num++;
|
||||
|
@ -772,7 +772,7 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
|
|||
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
|
||||
const auto &abstract = from_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
|
||||
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
|
||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||
|
||||
// Link arrow according to abstract.
|
||||
|
@ -912,7 +912,7 @@ void ControlNodeScheduler::LinkArrowByParameter(const AnfNodePtr ¶meter, Con
|
|||
|
||||
auto abstract = parameter->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto dst_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second);
|
||||
auto dst_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
|
||||
MS_EXCEPTION_IF_NULL(dst_abstract);
|
||||
if (dst_abstract->isa<abstract::AbstractFunction>()) {
|
||||
SchedulerHelper::AddPartialArrow(from_actor, to_actor, from_actor->FetchNodePosition(from_node_with_index),
|
||||
|
@ -937,7 +937,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
|
|||
// Link arrow from exit actor to control actor.
|
||||
const auto &abstract = call_node->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
const auto &real_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second);
|
||||
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
|
||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||
|
||||
const auto &func_graphs = parser->FetchFuncGraphbyCallNode(from_node);
|
||||
|
|
|
@ -751,6 +751,23 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
|
|||
|
||||
TypeId AnfAlgo::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsCallNode(node)) {
|
||||
if (node->abstract() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Empty abstract of call node:" << node->DebugString();
|
||||
}
|
||||
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), output_idx);
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
const auto &type = abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->isa<TensorType>()) {
|
||||
const auto &tensor_type = type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
const auto &element = tensor_type->element();
|
||||
return element->type_id();
|
||||
} else {
|
||||
return type->type_id();
|
||||
}
|
||||
}
|
||||
return GetOutputInferDataType(node->Type(), output_idx);
|
||||
}
|
||||
|
||||
|
@ -1949,5 +1966,29 @@ abstract::BaseShapePtr AnfAlgo::GetDynamicSequenceShape(const AnfNodePtr &node,
|
|||
}
|
||||
return element_abs->BuildShape();
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr AnfAlgo::FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
if (!abstract->isa<abstract::AbstractSequence>() || abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
|
||||
if (index != 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
return abstract;
|
||||
}
|
||||
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
const auto &sub_abstracts = tuple_abstract->elements();
|
||||
size_t real_index = index;
|
||||
for (const auto &sub_abstract : sub_abstracts) {
|
||||
size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
|
||||
if (real_index >= tmp_index) {
|
||||
real_index -= tmp_index;
|
||||
continue;
|
||||
}
|
||||
return FetchAbstractByIndex(sub_abstract, real_index);
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue