forked from mindspore-Ecosystem/mindspore
!25323 Get output nodes by control node.
Merge pull request !25323 from gaoyong10/for_d_controlflow
This commit is contained in:
commit
37199b282e
|
@ -49,6 +49,11 @@ namespace {
|
|||
constexpr size_t kNopNodeInputSize = 2;
|
||||
constexpr size_t kNopNodeRealInputIndex = 1;
|
||||
constexpr size_t kReturnDataIndex = 1;
|
||||
constexpr size_t kSwitchTrueBranchIndex = 2;
|
||||
constexpr size_t kPartialFuncGraphPos = 1;
|
||||
constexpr size_t kSwitchLayerBranchPos = 2;
|
||||
constexpr size_t kSwitchTrueBranchPos = 2;
|
||||
constexpr size_t kMakeTupleInputStartPos = 1;
|
||||
|
||||
const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
|
||||
|
||||
|
@ -142,6 +147,54 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
|
|||
return inputs->push_back(std::make_pair(node, output_index));
|
||||
}
|
||||
|
||||
// Fetch all outputs of control nodes, visited nodes indicates the call node that has been processed. In control flow,
|
||||
// there are recursive calls between funcgraphs, so the processed call nodes are recorded to prevent infinite loops.
|
||||
std::vector<KernelWithIndex> GetAllOutputByControlFlowNode(const KernelWithIndex &output_with_index,
|
||||
std::set<AnfNodePtr> *visited_call_nodes) {
|
||||
std::vector<KernelWithIndex> ret;
|
||||
const auto &node = output_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
|
||||
const auto &switch_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_cnode);
|
||||
const auto &switch_inputs = switch_cnode->inputs();
|
||||
auto output_vector = AnfAlgo::GetAllOutputWithIndex(switch_inputs[kSwitchTrueBranchIndex], visited_call_nodes);
|
||||
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
|
||||
} else if (AnfAlgo::IsCallNode(node)) {
|
||||
if (visited_call_nodes != nullptr) {
|
||||
if (visited_call_nodes->find(node) != visited_call_nodes->end()) {
|
||||
return ret;
|
||||
} else {
|
||||
visited_call_nodes->emplace(node);
|
||||
}
|
||||
}
|
||||
|
||||
// The output of the call node is the output of the funcgraph actually called.
|
||||
const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(node);
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
// The call in the graph kernel does not need to be parsed, and the node is directly output.
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
ret.emplace_back(output_with_index);
|
||||
break;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(func_graph->output());
|
||||
const auto &func_graph_output =
|
||||
AnfAlgo::VisitKernelWithReturnType(func_graph->output(), output_with_index.second);
|
||||
std::set<AnfNodePtr> tmp_visited_nodes = {node};
|
||||
auto output_vector = AnfAlgo::GetAllOutputWithIndex(
|
||||
func_graph_output.first, (visited_call_nodes == nullptr ? &tmp_visited_nodes : visited_call_nodes));
|
||||
if (output_with_index.second < output_vector.size()) {
|
||||
ret.emplace_back(output_vector[output_with_index.second]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
// ops pair that dynamic input order is differ from the fixed shape ops
|
||||
// pair: <real_input->ori_input, ori_input->real_input>
|
||||
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
|
||||
|
@ -339,7 +392,8 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
|
||||
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node,
|
||||
std::set<AnfNodePtr> *visited_call_nodes) {
|
||||
std::vector<KernelWithIndex> ret;
|
||||
std::vector<KernelWithIndex> ret_empty;
|
||||
|
||||
|
@ -348,7 +402,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
auto make_tuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
|
||||
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i));
|
||||
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i), visited_call_nodes);
|
||||
(void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
|
||||
}
|
||||
return ret;
|
||||
|
@ -358,7 +412,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
||||
auto depend_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend));
|
||||
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend), visited_call_nodes);
|
||||
(void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
|
||||
return ret;
|
||||
}
|
||||
|
@ -393,20 +447,16 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
|||
|
||||
// The makeTuple node need recurse.
|
||||
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
|
||||
auto output_vector = GetAllOutputWithIndex(output_with_index.first);
|
||||
auto output_vector = GetAllOutputWithIndex(output_with_index.first, visited_call_nodes);
|
||||
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ignore the output of front call node.
|
||||
if (output_with_index.first->isa<CNode>()) {
|
||||
auto cnode = output_with_index.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
if (inputs[0]->isa<CNode>()) {
|
||||
MS_LOG(INFO) << "The output is call node: " << output_with_index.first->DebugString();
|
||||
return ret_empty;
|
||||
}
|
||||
// Fetch outputs by control nodes.
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || AnfAlgo::IsCallNode(node)) {
|
||||
const auto &control_node_output = GetAllOutputByControlFlowNode(output_with_index, visited_call_nodes);
|
||||
(void)std::copy(control_node_output.begin(), control_node_output.end(), std::back_inserter(ret));
|
||||
continue;
|
||||
}
|
||||
|
||||
// The InitDataSetQueue node has no output.
|
||||
|
@ -2526,5 +2576,100 @@ size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
|
|||
}
|
||||
return output_tensor_size_[index];
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsCallNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.empty() || inputs[0] == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Invalid call node:" << node->DebugString();
|
||||
}
|
||||
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
|
||||
}
|
||||
|
||||
std::set<FuncGraphPtr> AnfRuntimeAlgorithm::GetFuncGraphbyCallNode(const AnfNodePtr &node, size_t call_depth) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::set<FuncGraphPtr> func_graphs;
|
||||
if (!node->isa<CNode>()) {
|
||||
return func_graphs;
|
||||
}
|
||||
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &call_input0 = cnode->input(0);
|
||||
MS_EXCEPTION_IF_NULL(call_input0);
|
||||
|
||||
if (AnfAlgo::IsCallNode(call_input0)) {
|
||||
return AnfAlgo::GetFuncGraphbyCallNode(call_input0, ++call_depth);
|
||||
}
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) {
|
||||
// First input node of call is switch node.
|
||||
const auto &switch_inputs = call_input0->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(switch_inputs[i]);
|
||||
(void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth));
|
||||
}
|
||||
} else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) {
|
||||
// First input node of call is switch layer node.
|
||||
const auto &tuple_node = cnode->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
|
||||
if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString()
|
||||
<< " for switch layer node:" << cnode->DebugString();
|
||||
}
|
||||
|
||||
const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_inputs[i]);
|
||||
func_graphs.emplace(GetFuncGraphFromPartial(tuple_inputs[i], call_depth));
|
||||
}
|
||||
} else if (IsPartial(call_input0)) {
|
||||
// First input node of call is partial node or value node of funcgraph.
|
||||
(void)func_graphs.emplace(GetFuncGraphFromPartial(call_input0, call_depth));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
|
||||
}
|
||||
return func_graphs;
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsPartial(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial);
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfRuntimeAlgorithm::GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (depth == 1) {
|
||||
if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) {
|
||||
// Value node of funcgraph.
|
||||
return GetValueNode<FuncGraphPtr>(node);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
// Partial cnode.
|
||||
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
|
||||
return GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid partial construct node:" << node->DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
// Get funcgraph in the output of inner call.
|
||||
if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) {
|
||||
return GetFuncGraphFromPartial(GetValueNode<FuncGraphPtr>(node)->output(), depth - 1);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
|
||||
return GetFuncGraphFromPartial(GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos])->output(),
|
||||
depth - 1);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -83,7 +83,8 @@ class AnfRuntimeAlgorithm {
|
|||
prim::kPrimMakeTuple});
|
||||
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
|
||||
const std::vector<PrimitivePtr> &return_types = {});
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
|
||||
std::set<AnfNodePtr> *visited_call_nodes = nullptr);
|
||||
// get cnode primitive
|
||||
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
|
||||
static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
|
||||
|
@ -329,6 +330,20 @@ class AnfRuntimeAlgorithm {
|
|||
static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph);
|
||||
static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
|
||||
static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
|
||||
// Check whether node is a call node, there are two types of call nodes:
|
||||
// 1. First input of node is a cnode.
|
||||
// 2. First input of node is a funcgraph value node.
|
||||
static bool IsCallNode(const AnfNodePtr &node);
|
||||
// Find all funcgraphs that the call node will call.
|
||||
static std::set<FuncGraphPtr> GetFuncGraphbyCallNode(const AnfNodePtr &node, size_t call_depth = 1);
|
||||
// Check whether node has a partial structure, a node is a partial structure whicih:
|
||||
// 1. a partial cnode.
|
||||
// 2. a funcgraph value node.
|
||||
static bool IsPartial(const AnfNodePtr &node);
|
||||
// Get funcgraph in partial structure.
|
||||
// Depth represents the number of layers of the call. When the first input of the call node is a call node,
|
||||
// the funcgraph in the return value of the inner call needs to be returned.
|
||||
static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth = 1);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -1373,6 +1373,23 @@ bool KernelGraph::IsDatasetGraph() const {
|
|||
|
||||
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
|
||||
|
||||
bool KernelGraph::IsChildGraphResult(const AnfNodePtr &node) {
|
||||
std::vector<AnfNodePtr> child_graph_results;
|
||||
for (const auto &child_graph_result : child_graph_result_) {
|
||||
MS_EXCEPTION_IF_NULL(child_graph_result);
|
||||
if (AnfAlgo::CheckPrimitiveType(child_graph_result, prim::kPrimMakeTuple)) {
|
||||
const auto cnode = child_graph_result->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &inputs = cnode->inputs();
|
||||
child_graph_results.insert(child_graph_results.end(), inputs.begin(), inputs.end());
|
||||
} else {
|
||||
child_graph_results.emplace_back(child_graph_result);
|
||||
}
|
||||
}
|
||||
|
||||
return find(child_graph_results.begin(), child_graph_results.end(), node) != child_graph_results.end();
|
||||
}
|
||||
|
||||
KernelGraph::~KernelGraph() {
|
||||
try {
|
||||
// Release the kernel resource.
|
||||
|
|
|
@ -260,6 +260,7 @@ class KernelGraph : public FuncGraph {
|
|||
void UpdateChildGraphOrder();
|
||||
const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
|
||||
void AddChildGraphResult(const AnfNodePtr ¶meter) { child_graph_result_.push_back(parameter); }
|
||||
bool IsChildGraphResult(const AnfNodePtr &node);
|
||||
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
|
||||
child_graph_result_ = child_graph_result;
|
||||
}
|
||||
|
|
|
@ -54,6 +54,7 @@ bool CheckValidFuncGraphInput(const AnfNodePtr &node) {
|
|||
|
||||
// Get the funcgraph in partial node.
|
||||
FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
|
||||
return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
|
||||
}
|
||||
|
@ -313,7 +314,7 @@ std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std
|
|||
if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) {
|
||||
return outputs;
|
||||
}
|
||||
if (!IsCallNode(real_output.first)) {
|
||||
if (!AnfAlgo::IsCallNode(real_output.first)) {
|
||||
outputs.push_back(real_output.first);
|
||||
return outputs;
|
||||
}
|
||||
|
@ -349,7 +350,7 @@ std::vector<AnfNodePtr> FetchOutputByCallNode(const AnfNodePtr &call_node, std::
|
|||
} else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
|
||||
const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
|
||||
(void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
|
||||
} else if (IsCallNode(graph_output)) {
|
||||
} else if (AnfAlgo::IsCallNode(graph_output)) {
|
||||
const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
|
||||
(void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
|
||||
} else if (graph_output->isa<CNode>()) {
|
||||
|
@ -388,7 +389,7 @@ std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, s
|
|||
} else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
|
||||
const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes);
|
||||
(void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
|
||||
} else if (IsCallNode(inputs[i])) {
|
||||
} else if (AnfAlgo::IsCallNode(inputs[i])) {
|
||||
const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes);
|
||||
(void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
|
||||
} else {
|
||||
|
@ -486,7 +487,7 @@ FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
|
|||
|
||||
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) {
|
||||
const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
||||
if (!IsCallNode(real_node)) {
|
||||
if (!AnfAlgo::IsCallNode(real_node)) {
|
||||
return real_node;
|
||||
}
|
||||
if ((*call_nodes).find(real_node) != (*call_nodes).end()) {
|
||||
|
@ -513,15 +514,6 @@ bool HasAbstractRef(const AnfNodePtr &node) {
|
|||
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
||||
}
|
||||
|
||||
bool IsCallNode(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
|
||||
}
|
||||
|
||||
bool IsSubCallNode(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
|
@ -604,7 +596,7 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
|
|||
func_graphs.emplace_back(func_graph);
|
||||
}
|
||||
}
|
||||
} else if (IsCallNode(cnode)) {
|
||||
} else if (AnfAlgo::IsCallNode(cnode)) {
|
||||
return FetchFuncGraphbyCallNode(cnode);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
|
||||
|
@ -618,7 +610,7 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes) {
|
||||
if (!IsCallNode(node)) {
|
||||
if (!AnfAlgo::IsCallNode(node)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node);
|
||||
}
|
||||
if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) {
|
||||
|
@ -631,7 +623,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
|
|||
const auto &output = func_graph->output();
|
||||
const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0);
|
||||
|
||||
if (IsCallNode(real_output.first)) {
|
||||
if (AnfAlgo::IsCallNode(real_output.first)) {
|
||||
size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes);
|
||||
if (output_num > 0) {
|
||||
return output_num;
|
||||
|
@ -642,7 +634,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
|
|||
const auto &inputs = tuple_cnode->inputs();
|
||||
size_t i = 1;
|
||||
for (; i < inputs.size(); ++i) {
|
||||
if (IsCallNode(inputs[i])) {
|
||||
if (AnfAlgo::IsCallNode(inputs[i])) {
|
||||
size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes);
|
||||
if (call_output_num == 0) {
|
||||
break;
|
||||
|
@ -872,7 +864,7 @@ void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node
|
|||
if (node_value->isa<tensor::Tensor>()) {
|
||||
(void)((*value_nodes).emplace_back(input));
|
||||
}
|
||||
} else if (IsCallNode(input)) {
|
||||
} else if (AnfAlgo::IsCallNode(input)) {
|
||||
// If input is a call not, should check the switch node in its input.
|
||||
const auto &call_node = input->cast<CNodePtr>();
|
||||
const auto &call_inputs = call_node->inputs();
|
||||
|
@ -1050,7 +1042,7 @@ void ControlNodeParser::FetchFrontToFrontParameter(
|
|||
std::vector<AnfNodePtr> call_inputs;
|
||||
call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end());
|
||||
switch_input_parse(inputs[0], call_inputs);
|
||||
} else if (IsCallNode(inputs[0])) {
|
||||
} else if (AnfAlgo::IsCallNode(inputs[0])) {
|
||||
continue;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:"
|
||||
|
@ -1098,7 +1090,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::
|
|||
|
||||
void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes) {
|
||||
for (const auto &control_node : control_nodes) {
|
||||
if (IsCallNode(control_node)) {
|
||||
if (AnfAlgo::IsCallNode(control_node)) {
|
||||
const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
|
||||
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
|
@ -1123,7 +1115,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphP
|
|||
const auto inputs = graph->input_nodes();
|
||||
for (const auto &input : inputs) {
|
||||
const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
|
||||
if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) {
|
||||
if (internal_parameter_with_index.first != nullptr && AnfAlgo::IsCallNode(internal_parameter_with_index.first)) {
|
||||
call_input_kernel_graphs_[graph] = device_context;
|
||||
call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context};
|
||||
}
|
||||
|
@ -1162,12 +1154,12 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node,
|
|||
for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
|
||||
if (inputs[i]->isa<Parameter>()) {
|
||||
(void)parameters.emplace_back(inputs[i]);
|
||||
} else if (IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
|
||||
} else if (AnfAlgo::IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
|
||||
const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes);
|
||||
(void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
|
||||
}
|
||||
}
|
||||
} else if (IsCallNode(node)) {
|
||||
} else if (AnfAlgo::IsCallNode(node)) {
|
||||
if ((*call_nodes).find(node) != (*call_nodes).end()) {
|
||||
return parameters;
|
||||
}
|
||||
|
@ -1296,7 +1288,7 @@ void ControlNodeParser::FetchFuncGraphToParameter(const std::vector<AnfNodePtr>
|
|||
} else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
|
||||
// Switchlayer node.
|
||||
FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
|
||||
} else if (IsCallNode(inputs[0])) {
|
||||
} else if (AnfAlgo::IsCallNode(inputs[0])) {
|
||||
continue;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
|
||||
|
@ -1373,7 +1365,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
|
|||
for (const auto &switch_output : switch_outputs) {
|
||||
FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results);
|
||||
}
|
||||
} else if (IsCallNode(front_output)) {
|
||||
} else if (AnfAlgo::IsCallNode(front_output)) {
|
||||
// Output is a call.
|
||||
const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes);
|
||||
|
||||
|
@ -1429,7 +1421,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
|
|||
}
|
||||
} else if (real_parameter->isa<ValueNode>()) {
|
||||
(void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0);
|
||||
} else if (IsCallNode(real_parameter)) {
|
||||
} else if (AnfAlgo::IsCallNode(real_parameter)) {
|
||||
const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
|
||||
for (const auto func_graph : func_graphs) {
|
||||
FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters);
|
||||
|
|
|
@ -56,11 +56,6 @@ using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNode
|
|||
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
|
||||
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
|
||||
|
||||
// Check whether node is a call node, there are two types of call nodes:
|
||||
// 1. First input of node is a cnode.
|
||||
// 2. First input of node is a funcgraph value node.
|
||||
bool IsCallNode(const AnfNodePtr &node);
|
||||
|
||||
// Check if the call node is the input of another call node.
|
||||
bool IsSubCallNode(const AnfNodePtr &node);
|
||||
|
||||
|
|
|
@ -1691,8 +1691,9 @@ void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, con
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(kernel_type);
|
||||
MS_EXCEPTION_IF_NULL(kernel_name);
|
||||
|
||||
if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>())) {
|
||||
// In sink mode, the data exchange between child graphs is expressed as parameters. These parameters are stored
|
||||
// in the graph and should be obtained from the super kernel actor.
|
||||
if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>() || graph->IsChildGraphResult(node))) {
|
||||
*kernel_type = KernelTransformType::kSuperKernelActor;
|
||||
*kernel_name = graph->ToString() + "_SuperKernelActor";
|
||||
return;
|
||||
|
|
|
@ -980,13 +980,6 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
|
|||
AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
|
||||
size_t position = 0;
|
||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
|
||||
if (runtime::IsCallNode(root_output)) {
|
||||
std::vector<AnfNodePtr> call_nodes;
|
||||
size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes);
|
||||
for (size_t i = 0; i < call_output_num; ++i) {
|
||||
(void)outputs.emplace_back(root_output, i);
|
||||
}
|
||||
}
|
||||
outputs_num = outputs.size();
|
||||
for (const auto &output : outputs) {
|
||||
if (outputs_order.count(output) == 0) {
|
||||
|
|
Loading…
Reference in New Issue