forked from mindspore-Ecosystem/mindspore
!46575 Fetch call node output type by abstract.
Merge pull request !46575 from gaoyong10/dynamic_shape_01
This commit is contained in:
commit
2cc4d8044f
|
@ -298,6 +298,8 @@ class COMMON_EXPORT AnfAlgo {
|
||||||
|
|
||||||
// Get the element shape of dynamic sequence shape.
|
// Get the element shape of dynamic sequence shape.
|
||||||
static abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx);
|
static abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx);
|
||||||
|
// Fetch the sub abstract from the top abstract by the index.
|
||||||
|
static abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
|
||||||
};
|
};
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -50,7 +50,7 @@ void StackActor::Init() {
|
||||||
MS_EXCEPTION_IF_NULL(formal_parameter.first);
|
MS_EXCEPTION_IF_NULL(formal_parameter.first);
|
||||||
const auto &abstract = formal_parameter.first->abstract();
|
const auto &abstract = formal_parameter.first->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
|
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
|
||||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||||
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
||||||
total_partials_num++;
|
total_partials_num++;
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include "runtime/graph_scheduler/control_node_parser.h"
|
#include "runtime/graph_scheduler/control_node_parser.h"
|
||||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||||
|
@ -319,7 +320,7 @@ TypeId FetchTypeIdByNode(const AnfNodePtr &node, size_t index) {
|
||||||
TypeId type_id = kTypeUnknown;
|
TypeId type_id = kTypeUnknown;
|
||||||
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
|
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
|
||||||
// For valuenode, fetch type from abstract.
|
// For valuenode, fetch type from abstract.
|
||||||
const auto &abs = FetchAbstractByIndex(node->abstract(), index);
|
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
|
||||||
MS_EXCEPTION_IF_NULL(abs);
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
const auto &type = abs->BuildType();
|
const auto &type = abs->BuildType();
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
@ -341,7 +342,7 @@ size_t FetchOutputSizeByNode(const AnfNodePtr &node, size_t index, TypeId type_i
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
size_t size = GetTypeByte(TypeIdToType(type_id));
|
size_t size = GetTypeByte(TypeIdToType(type_id));
|
||||||
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
|
if (node->isa<ValueNode>() && node->abstract() != nullptr) {
|
||||||
const auto &abs = FetchAbstractByIndex(node->abstract(), index);
|
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), index);
|
||||||
MS_EXCEPTION_IF_NULL(abs);
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
const auto &shape_ptr = abs->BuildShape();
|
const auto &shape_ptr = abs->BuildShape();
|
||||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||||
|
@ -738,30 +739,6 @@ std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) {
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
|
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
|
||||||
if (!abstract->isa<abstract::AbstractSequence>() || abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
|
|
||||||
if (index != 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
|
||||||
}
|
|
||||||
return abstract;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
|
||||||
const auto &sub_abstracts = tuple_abstract->elements();
|
|
||||||
size_t real_index = index;
|
|
||||||
for (const auto &sub_abstract : sub_abstracts) {
|
|
||||||
size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
|
|
||||||
if (real_index >= tmp_index) {
|
|
||||||
real_index -= tmp_index;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return FetchAbstractByIndex(sub_abstract, real_index);
|
|
||||||
}
|
|
||||||
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsPartialInput(const AnfNodePtr &node) {
|
bool IsPartialInput(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
const auto &abstract = node->abstract();
|
const auto &abstract = node->abstract();
|
||||||
|
@ -2262,7 +2239,7 @@ void CollectEffectiveOutputByGraph(const KernelGraphPtr &graph, DeviceContext *c
|
||||||
// Skip the function input.
|
// Skip the function input.
|
||||||
const auto &abstract = backend_to_front.second.first->abstract();
|
const auto &abstract = backend_to_front.second.first->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
const auto &real_abstract = FetchAbstractByIndex(abstract, backend_to_front.second.second);
|
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, backend_to_front.second.second);
|
||||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||||
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -115,8 +115,6 @@ bool IsCooNode(const AnfNodePtr &node);
|
||||||
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph);
|
KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, const KernelGraph *const graph);
|
||||||
// Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem.
|
// Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem.
|
||||||
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
|
std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node);
|
||||||
// Fetch the sub abstract from the top abstract by the index.
|
|
||||||
abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
|
|
||||||
// Fetch the real input of tuple get item node.
|
// Fetch the real input of tuple get item node.
|
||||||
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
|
KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index);
|
||||||
// Check if the partial node is valid.
|
// Check if the partial node is valid.
|
||||||
|
|
|
@ -535,7 +535,7 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo
|
||||||
|
|
||||||
const auto &abstract = parameter.first->abstract();
|
const auto &abstract = parameter.first->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
const auto &real_abstract = FetchAbstractByIndex(abstract, parameter.second);
|
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, parameter.second);
|
||||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||||
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
if (real_abstract->isa<abstract::AbstractFunction>()) {
|
||||||
input_parameter_partials_num++;
|
input_parameter_partials_num++;
|
||||||
|
@ -772,7 +772,7 @@ void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor
|
||||||
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
|
size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
|
||||||
const auto &abstract = from_node->abstract();
|
const auto &abstract = from_node->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second);
|
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
|
||||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||||
|
|
||||||
// Link arrow according to abstract.
|
// Link arrow according to abstract.
|
||||||
|
@ -912,7 +912,7 @@ void ControlNodeScheduler::LinkArrowByParameter(const AnfNodePtr ¶meter, Con
|
||||||
|
|
||||||
auto abstract = parameter->abstract();
|
auto abstract = parameter->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
auto dst_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second);
|
auto dst_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
|
||||||
MS_EXCEPTION_IF_NULL(dst_abstract);
|
MS_EXCEPTION_IF_NULL(dst_abstract);
|
||||||
if (dst_abstract->isa<abstract::AbstractFunction>()) {
|
if (dst_abstract->isa<abstract::AbstractFunction>()) {
|
||||||
SchedulerHelper::AddPartialArrow(from_actor, to_actor, from_actor->FetchNodePosition(from_node_with_index),
|
SchedulerHelper::AddPartialArrow(from_actor, to_actor, from_actor->FetchNodePosition(from_node_with_index),
|
||||||
|
@ -937,7 +937,7 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont
|
||||||
// Link arrow from exit actor to control actor.
|
// Link arrow from exit actor to control actor.
|
||||||
const auto &abstract = call_node->abstract();
|
const auto &abstract = call_node->abstract();
|
||||||
MS_EXCEPTION_IF_NULL(abstract);
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
const auto &real_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second);
|
const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
|
||||||
MS_EXCEPTION_IF_NULL(real_abstract);
|
MS_EXCEPTION_IF_NULL(real_abstract);
|
||||||
|
|
||||||
const auto &func_graphs = parser->FetchFuncGraphbyCallNode(from_node);
|
const auto &func_graphs = parser->FetchFuncGraphbyCallNode(from_node);
|
||||||
|
|
|
@ -751,6 +751,23 @@ TypeId AnfAlgo::GetOutputInferDataType(const TypePtr &type, size_t output_idx) {
|
||||||
|
|
||||||
TypeId AnfAlgo::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
|
TypeId AnfAlgo::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (IsCallNode(node)) {
|
||||||
|
if (node->abstract() == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Empty abstract of call node:" << node->DebugString();
|
||||||
|
}
|
||||||
|
const auto &abs = common::AnfAlgo::FetchAbstractByIndex(node->abstract(), output_idx);
|
||||||
|
MS_EXCEPTION_IF_NULL(abs);
|
||||||
|
const auto &type = abs->BuildType();
|
||||||
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
if (type->isa<TensorType>()) {
|
||||||
|
const auto &tensor_type = type->cast<TensorTypePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||||
|
const auto &element = tensor_type->element();
|
||||||
|
return element->type_id();
|
||||||
|
} else {
|
||||||
|
return type->type_id();
|
||||||
|
}
|
||||||
|
}
|
||||||
return GetOutputInferDataType(node->Type(), output_idx);
|
return GetOutputInferDataType(node->Type(), output_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1949,5 +1966,29 @@ abstract::BaseShapePtr AnfAlgo::GetDynamicSequenceShape(const AnfNodePtr &node,
|
||||||
}
|
}
|
||||||
return element_abs->BuildShape();
|
return element_abs->BuildShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr AnfAlgo::FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) {
|
||||||
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
|
if (!abstract->isa<abstract::AbstractSequence>() || abstract->cast<abstract::AbstractSequencePtr>()->dynamic_len()) {
|
||||||
|
if (index != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||||
|
}
|
||||||
|
return abstract;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||||
|
const auto &sub_abstracts = tuple_abstract->elements();
|
||||||
|
size_t real_index = index;
|
||||||
|
for (const auto &sub_abstract : sub_abstracts) {
|
||||||
|
size_t tmp_index = common::AnfAlgo::GetOutputNumByAbstract(sub_abstract);
|
||||||
|
if (real_index >= tmp_index) {
|
||||||
|
real_index -= tmp_index;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return FetchAbstractByIndex(sub_abstract, real_index);
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString();
|
||||||
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue