!26112 Move functions from AnfAlgo to AnfUtils

Merge pull request !26112 from DeshiChen/1105_anfalgo
This commit is contained in:
i-robot 2021-11-15 01:53:23 +00:00 committed by Gitee
commit 5e12f336f2
3 changed files with 159 additions and 124 deletions

View File

@ -213,43 +213,8 @@ size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_it
}
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
if (anf_node->isa<ValueNode>()) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<Parameter>()) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<CNode>()) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input0 = cnode->input(0);
MS_EXCEPTION_IF_NULL(input0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
if (AnfAlgo::GetInputTensorNum(cnode) == 0) {
return std::make_pair(nullptr, 0);
}
auto node = cnode->input(index + IntToSize(1));
MS_EXCEPTION_IF_NULL(node);
return VisitKernel(node, 0);
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(input2);
auto value_node = input2->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto item_idx = GetValue<int64_t>(value_node->value());
return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), LongToSize(item_idx));
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
return VisitKernel(cnode->input(kUpdateStateRealInput), 0);
} else if (IsOneOfPrimitive(input0, follow_first_input_prims)) {
return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
} else {
return std::make_pair(anf_node, index);
}
} else {
MS_LOG(EXCEPTION) << "The input is invalid";
}
// this function was moved to AnfUtils.
return AnfUtils::VisitKernel(anf_node, index);
}
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
@ -489,28 +454,8 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetCNodeFuncGraphPtr(const AnfNodePtr &node) {
}
std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto primitive = AnfAlgo::GetCNodePrimitive(node);
if (primitive != nullptr) {
if (primitive->name() == "Custom") {
auto uniq_name = primitive->GetAttr("uniq_name");
if (uniq_name) {
return GetValue<std::string>(uniq_name);
}
}
return primitive->name();
}
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string fg_name = "GraphKernel_";
fg_name += GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
return fg_name;
}
return func_graph->ToString();
}
MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
// this function was moved to AnfUtils.
return AnfUtils::GetCNodeName(node);
}
std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
@ -614,55 +559,13 @@ size_t AnfRuntimeAlgorithm::GetInputNum(const CNodePtr &cnode) {
}
size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString()
<< " trace: " << trace::DumpSourceLines(node);
}
ssize_t input_tensor_num = cnode->input_tensor_num();
if (input_tensor_num >= 0) {
return static_cast<size_t>(input_tensor_num);
}
size_t input_num = cnode->inputs().size();
if (input_num == 0) {
MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"
<< " trace: " << trace::DumpSourceLines(node);
}
// Exclude inputs[0].
--input_num;
// Exclude monad inputs for real cnodes.
if (input_num > 0 && AnfUtils::IsRealKernel(cnode)) {
auto &inputs = cnode->inputs();
// Search monad inputs, backward.
for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
if (!HasAbstractMonad(*iter)) {
// Stop count if we encounter a non-monad input.
break;
}
--input_num;
}
}
cnode->set_input_tensor_num(static_cast<ssize_t>(input_num));
return input_num;
// this function was moved to AnfUtils.
return AnfUtils::GetInputTensorNum(node);
}
size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
TypePtr type = node->Type();
if (type == nullptr) {
return 0;
}
if (type->isa<Tuple>()) {
auto tuple_type = type->cast<TuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_type);
return tuple_type->size();
}
if (type->isa<TypeNone>()) {
return 0;
}
return 1;
// this function was moved to AnfUtils.
return AnfUtils::GetOutputTensorNum(node);
}
size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) {
@ -1431,28 +1334,13 @@ void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *
}
bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// graph kernel should be a real cnode kernel.
if (!AnfUtils::IsRealCNodeKernel(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input = cnode->input(kAnfPrimitiveIndex);
// graph kernel should has func_graph as first input.
if (!IsValueNode<FuncGraph>(input)) {
return false;
}
auto func_graph = GetValueNode<FuncGraphPtr>(input);
MS_EXCEPTION_IF_NULL(func_graph);
return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
// this function was moved to AnfUtils.
return AnfUtils::IsGraphKernel(node);
}
bool AnfRuntimeAlgorithm::IsNodeInGraphKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
// this function was moved to AnfUtils.
return AnfUtils::IsNodeInGraphKernel(node);
}
AnfNodePtr AnfRuntimeAlgorithm::GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index) {

View File

@ -15,11 +15,16 @@
*/
#include "utils/anf_utils.h"
#include <string>
#include "base/core_ops.h"
#include "utils/trace_base.h"
#include "utils/utils.h"
namespace mindspore {
namespace {
const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
} // namespace
bool AnfUtils::IsDimUnknown(const abstract::ShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < -1; });
@ -117,4 +122,132 @@ bool AnfUtils::IsRealCNodeKernel(const AnfNodePtr &node) {
}
return AnfUtils::IsRealKernel(node);
}
std::string AnfUtils::GetCNodeName(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto primitive = GetCNodePrimitive(node);
if (primitive != nullptr) {
if (primitive->name() == "Custom") {
auto uniq_name = primitive->GetAttr("uniq_name");
if (uniq_name) {
return GetValue<std::string>(uniq_name);
}
}
return primitive->name();
}
auto func_graph = GetCNodeFuncGraph(node);
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string fg_name = "GraphKernel_";
fg_name += GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
return fg_name;
}
return func_graph->ToString();
}
MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString() << " trace: " << trace::DumpSourceLines(node);
}
size_t AnfUtils::GetInputTensorNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString()
<< " trace: " << trace::DumpSourceLines(node);
}
ssize_t input_tensor_num = cnode->input_tensor_num();
if (input_tensor_num >= 0) {
return static_cast<size_t>(input_tensor_num);
}
size_t input_num = cnode->inputs().size();
if (input_num == 0) {
MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"
<< " trace: " << trace::DumpSourceLines(node);
}
// Exclude inputs[0].
--input_num;
// Exclude monad inputs for real cnodes.
if (input_num > 0 && AnfUtils::IsRealKernel(cnode)) {
auto &inputs = cnode->inputs();
// Search monad inputs, backward.
for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
if (!HasAbstractMonad(*iter)) {
// Stop count if we encounter a non-monad input.
break;
}
--input_num;
}
}
cnode->set_input_tensor_num(static_cast<ssize_t>(input_num));
return input_num;
}
size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
TypePtr type = node->Type();
if (type == nullptr) {
return 0;
}
if (type->isa<Tuple>()) {
auto tuple_type = type->cast<TuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_type);
return tuple_type->size();
}
if (type->isa<TypeNone>()) {
return 0;
}
return 1;
}
std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
if (anf_node->isa<ValueNode>()) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<Parameter>()) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<CNode>()) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input0 = cnode->input(0);
MS_EXCEPTION_IF_NULL(input0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
if (GetInputTensorNum(cnode) == 0) {
return std::make_pair(nullptr, 0);
}
auto node = cnode->input(index + IntToSize(1));
MS_EXCEPTION_IF_NULL(node);
return VisitKernel(node, 0);
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(input2);
auto value_node = input2->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto item_idx = GetValue<int64_t>(value_node->value());
return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), LongToSize(item_idx));
} else if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
return VisitKernel(cnode->input(kUpdateStateRealInput), 0);
} else if (IsOneOfPrimitive(input0, follow_first_input_prims)) {
return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
} else {
return std::make_pair(anf_node, index);
}
} else {
MS_LOG(EXCEPTION) << "The input is invalid";
}
}
bool AnfUtils::IsGraphKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto func_graph = GetCNodeFuncGraph(node);
return func_graph != nullptr && func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
bool AnfUtils::IsNodeInGraphKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
} // namespace mindspore

View File

@ -17,6 +17,8 @@
#ifndef MINDSPORE_CORE_UTILS_ANF_UTILS_H_
#define MINDSPORE_CORE_UTILS_ANF_UTILS_H_
#include <vector>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "base/base.h"
@ -34,6 +36,18 @@ class AnfUtils {
static bool IsRealKernel(const AnfNodePtr &node);
// check whether the anf node is a real kernel that is a cnode and can run on device
static bool IsRealCNodeKernel(const AnfNodePtr &node);
// get kernel name of anf node
static std::string GetCNodeName(const AnfNodePtr &node);
// get the num of inputs exclude monads for real_kernel (which can be build and run in device)
static size_t GetInputTensorNum(const AnfNodePtr &node);
// get the num of output real_kernel(which can be build and run in device)
static size_t GetOutputTensorNum(const AnfNodePtr &node);
// get the node's real kernel recursively
static std::pair<AnfNodePtr, size_t> VisitKernel(const AnfNodePtr &anf_node, size_t index);
// check whether the node is a GraphKernel node.
static bool IsGraphKernel(const AnfNodePtr &node);
// check whether the node is a node in GraphKernel's subgraph.
static bool IsNodeInGraphKernel(const AnfNodePtr &node);
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_ANF_UTILS_H_