forked from mindspore-Ecosystem/mindspore
Get output nodes by control node.
This commit is contained in:
parent
1e119e7756
commit
83ee99b07d
|
@ -49,6 +49,11 @@ namespace {
|
||||||
constexpr size_t kNopNodeInputSize = 2;
|
constexpr size_t kNopNodeInputSize = 2;
|
||||||
constexpr size_t kNopNodeRealInputIndex = 1;
|
constexpr size_t kNopNodeRealInputIndex = 1;
|
||||||
constexpr size_t kReturnDataIndex = 1;
|
constexpr size_t kReturnDataIndex = 1;
|
||||||
|
constexpr size_t kSwitchTrueBranchIndex = 2;
|
||||||
|
constexpr size_t kPartialFuncGraphPos = 1;
|
||||||
|
constexpr size_t kSwitchLayerBranchPos = 2;
|
||||||
|
constexpr size_t kSwitchTrueBranchPos = 2;
|
||||||
|
constexpr size_t kMakeTupleInputStartPos = 1;
|
||||||
|
|
||||||
const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
|
const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
|
||||||
|
|
||||||
|
@ -142,6 +147,54 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
|
||||||
return inputs->push_back(std::make_pair(node, output_index));
|
return inputs->push_back(std::make_pair(node, output_index));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch all outputs of control nodes, visited nodes indicates the call node that has been processed. In control flow,
|
||||||
|
// there are recursive calls between funcgraphs, so the processed call nodes are recorded to prevent infinite loops.
|
||||||
|
std::vector<KernelWithIndex> GetAllOutputByControlFlowNode(const KernelWithIndex &output_with_index,
|
||||||
|
std::set<AnfNodePtr> *visited_call_nodes) {
|
||||||
|
std::vector<KernelWithIndex> ret;
|
||||||
|
const auto &node = output_with_index.first;
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
|
||||||
|
const auto &switch_cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(switch_cnode);
|
||||||
|
const auto &switch_inputs = switch_cnode->inputs();
|
||||||
|
auto output_vector = AnfAlgo::GetAllOutputWithIndex(switch_inputs[kSwitchTrueBranchIndex], visited_call_nodes);
|
||||||
|
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
|
||||||
|
} else if (AnfAlgo::IsCallNode(node)) {
|
||||||
|
if (visited_call_nodes != nullptr) {
|
||||||
|
if (visited_call_nodes->find(node) != visited_call_nodes->end()) {
|
||||||
|
return ret;
|
||||||
|
} else {
|
||||||
|
visited_call_nodes->emplace(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The output of the call node is the output of the funcgraph actually called.
|
||||||
|
const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(node);
|
||||||
|
for (const auto &func_graph : func_graphs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
// The call in the graph kernel does not need to be parsed, and the node is directly output.
|
||||||
|
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||||
|
ret.emplace_back(output_with_index);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph->output());
|
||||||
|
const auto &func_graph_output =
|
||||||
|
AnfAlgo::VisitKernelWithReturnType(func_graph->output(), output_with_index.second);
|
||||||
|
std::set<AnfNodePtr> tmp_visited_nodes = {node};
|
||||||
|
auto output_vector = AnfAlgo::GetAllOutputWithIndex(
|
||||||
|
func_graph_output.first, (visited_call_nodes == nullptr ? &tmp_visited_nodes : visited_call_nodes));
|
||||||
|
if (output_with_index.second < output_vector.size()) {
|
||||||
|
ret.emplace_back(output_vector[output_with_index.second]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
// ops pair that dynamic input order is differ from the fixed shape ops
|
// ops pair that dynamic input order is differ from the fixed shape ops
|
||||||
// pair: <real_input->ori_input, ori_input->real_input>
|
// pair: <real_input->ori_input, ori_input->real_input>
|
||||||
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
|
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
|
||||||
|
@ -339,7 +392,8 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
|
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node,
|
||||||
|
std::set<AnfNodePtr> *visited_call_nodes) {
|
||||||
std::vector<KernelWithIndex> ret;
|
std::vector<KernelWithIndex> ret;
|
||||||
std::vector<KernelWithIndex> ret_empty;
|
std::vector<KernelWithIndex> ret_empty;
|
||||||
|
|
||||||
|
@ -348,7 +402,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
||||||
auto make_tuple = node->cast<CNodePtr>();
|
auto make_tuple = node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||||
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
|
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
|
||||||
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i));
|
auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i), visited_call_nodes);
|
||||||
(void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
|
(void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret));
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -358,7 +412,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
||||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
||||||
auto depend_node = node->cast<CNodePtr>();
|
auto depend_node = node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(depend_node);
|
MS_EXCEPTION_IF_NULL(depend_node);
|
||||||
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend));
|
auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend), visited_call_nodes);
|
||||||
(void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
|
(void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret));
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -393,20 +447,16 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
|
||||||
|
|
||||||
// The makeTuple node need recurse.
|
// The makeTuple node need recurse.
|
||||||
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
|
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
|
||||||
auto output_vector = GetAllOutputWithIndex(output_with_index.first);
|
auto output_vector = GetAllOutputWithIndex(output_with_index.first, visited_call_nodes);
|
||||||
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
|
(void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ignore the output of front call node.
|
// Fetch outputs by control nodes.
|
||||||
if (output_with_index.first->isa<CNode>()) {
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || AnfAlgo::IsCallNode(node)) {
|
||||||
auto cnode = output_with_index.first->cast<CNodePtr>();
|
const auto &control_node_output = GetAllOutputByControlFlowNode(output_with_index, visited_call_nodes);
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
(void)std::copy(control_node_output.begin(), control_node_output.end(), std::back_inserter(ret));
|
||||||
auto inputs = cnode->inputs();
|
continue;
|
||||||
if (inputs[0]->isa<CNode>()) {
|
|
||||||
MS_LOG(INFO) << "The output is call node: " << output_with_index.first->DebugString();
|
|
||||||
return ret_empty;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The InitDataSetQueue node has no output.
|
// The InitDataSetQueue node has no output.
|
||||||
|
@ -2527,5 +2577,100 @@ size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
|
||||||
}
|
}
|
||||||
return output_tensor_size_[index];
|
return output_tensor_size_[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AnfRuntimeAlgorithm::IsCallNode(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const auto &cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
|
||||||
|
const auto &inputs = cnode->inputs();
|
||||||
|
if (inputs.empty() || inputs[0] == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid call node:" << node->DebugString();
|
||||||
|
}
|
||||||
|
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::set<FuncGraphPtr> AnfRuntimeAlgorithm::GetFuncGraphbyCallNode(const AnfNodePtr &node, size_t call_depth) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
std::set<FuncGraphPtr> func_graphs;
|
||||||
|
if (!node->isa<CNode>()) {
|
||||||
|
return func_graphs;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
const auto &call_input0 = cnode->input(0);
|
||||||
|
MS_EXCEPTION_IF_NULL(call_input0);
|
||||||
|
|
||||||
|
if (AnfAlgo::IsCallNode(call_input0)) {
|
||||||
|
return AnfAlgo::GetFuncGraphbyCallNode(call_input0, ++call_depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) {
|
||||||
|
// First input node of call is switch node.
|
||||||
|
const auto &switch_inputs = call_input0->cast<CNodePtr>()->inputs();
|
||||||
|
for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) {
|
||||||
|
MS_EXCEPTION_IF_NULL(switch_inputs[i]);
|
||||||
|
(void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth));
|
||||||
|
}
|
||||||
|
} else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) {
|
||||||
|
// First input node of call is switch layer node.
|
||||||
|
const auto &tuple_node = cnode->cast<CNodePtr>()->input(kSwitchLayerBranchPos);
|
||||||
|
if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString()
|
||||||
|
<< " for switch layer node:" << cnode->DebugString();
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs();
|
||||||
|
for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) {
|
||||||
|
MS_EXCEPTION_IF_NULL(tuple_inputs[i]);
|
||||||
|
func_graphs.emplace(GetFuncGraphFromPartial(tuple_inputs[i], call_depth));
|
||||||
|
}
|
||||||
|
} else if (IsPartial(call_input0)) {
|
||||||
|
// First input node of call is partial node or value node of funcgraph.
|
||||||
|
(void)func_graphs.emplace(GetFuncGraphFromPartial(call_input0, call_depth));
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
|
||||||
|
}
|
||||||
|
return func_graphs;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfRuntimeAlgorithm::IsPartial(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
return (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) ||
|
||||||
|
AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial);
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr AnfRuntimeAlgorithm::GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (depth == 1) {
|
||||||
|
if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) {
|
||||||
|
// Value node of funcgraph.
|
||||||
|
return GetValueNode<FuncGraphPtr>(node);
|
||||||
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||||
|
// Partial cnode.
|
||||||
|
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
|
||||||
|
return GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid partial construct node:" << node->DebugString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get funcgraph in the output of inner call.
|
||||||
|
if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) {
|
||||||
|
return GetFuncGraphFromPartial(GetValueNode<FuncGraphPtr>(node)->output(), depth - 1);
|
||||||
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||||
|
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
|
||||||
|
return GetFuncGraphFromPartial(GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos])->output(),
|
||||||
|
depth - 1);
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString();
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -83,7 +83,8 @@ class AnfRuntimeAlgorithm {
|
||||||
prim::kPrimMakeTuple});
|
prim::kPrimMakeTuple});
|
||||||
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
|
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
|
||||||
const std::vector<PrimitivePtr> &return_types = {});
|
const std::vector<PrimitivePtr> &return_types = {});
|
||||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
|
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
|
||||||
|
std::set<AnfNodePtr> *visited_call_nodes = nullptr);
|
||||||
// get cnode primitive
|
// get cnode primitive
|
||||||
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
|
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
|
||||||
static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
|
static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
|
||||||
|
@ -329,6 +330,20 @@ class AnfRuntimeAlgorithm {
|
||||||
static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph);
|
static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph);
|
||||||
static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
|
static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
|
||||||
static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
|
static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
|
||||||
|
// Check whether node is a call node, there are two types of call nodes:
|
||||||
|
// 1. First input of node is a cnode.
|
||||||
|
// 2. First input of node is a funcgraph value node.
|
||||||
|
static bool IsCallNode(const AnfNodePtr &node);
|
||||||
|
// Find all funcgraphs that the call node will call.
|
||||||
|
static std::set<FuncGraphPtr> GetFuncGraphbyCallNode(const AnfNodePtr &node, size_t call_depth = 1);
|
||||||
|
// Check whether node has a partial structure, a node is a partial structure whicih:
|
||||||
|
// 1. a partial cnode.
|
||||||
|
// 2. a funcgraph value node.
|
||||||
|
static bool IsPartial(const AnfNodePtr &node);
|
||||||
|
// Get funcgraph in partial structure.
|
||||||
|
// Depth represents the number of layers of the call. When the first input of the call node is a call node,
|
||||||
|
// the funcgraph in the return value of the inner call needs to be returned.
|
||||||
|
static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth = 1);
|
||||||
};
|
};
|
||||||
} // namespace session
|
} // namespace session
|
||||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||||
|
|
|
@ -1373,6 +1373,23 @@ bool KernelGraph::IsDatasetGraph() const {
|
||||||
|
|
||||||
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
|
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
|
||||||
|
|
||||||
|
bool KernelGraph::IsChildGraphResult(const AnfNodePtr &node) {
|
||||||
|
std::vector<AnfNodePtr> child_graph_results;
|
||||||
|
for (const auto &child_graph_result : child_graph_result_) {
|
||||||
|
MS_EXCEPTION_IF_NULL(child_graph_result);
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(child_graph_result, prim::kPrimMakeTuple)) {
|
||||||
|
const auto cnode = child_graph_result->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
const auto &inputs = cnode->inputs();
|
||||||
|
child_graph_results.insert(child_graph_results.end(), inputs.begin(), inputs.end());
|
||||||
|
} else {
|
||||||
|
child_graph_results.emplace_back(child_graph_result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return find(child_graph_results.begin(), child_graph_results.end(), node) != child_graph_results.end();
|
||||||
|
}
|
||||||
|
|
||||||
KernelGraph::~KernelGraph() {
|
KernelGraph::~KernelGraph() {
|
||||||
try {
|
try {
|
||||||
// Release the kernel resource.
|
// Release the kernel resource.
|
||||||
|
|
|
@ -260,6 +260,7 @@ class KernelGraph : public FuncGraph {
|
||||||
void UpdateChildGraphOrder();
|
void UpdateChildGraphOrder();
|
||||||
const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
|
const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
|
||||||
void AddChildGraphResult(const AnfNodePtr ¶meter) { child_graph_result_.push_back(parameter); }
|
void AddChildGraphResult(const AnfNodePtr ¶meter) { child_graph_result_.push_back(parameter); }
|
||||||
|
bool IsChildGraphResult(const AnfNodePtr &node);
|
||||||
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
|
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
|
||||||
child_graph_result_ = child_graph_result;
|
child_graph_result_ = child_graph_result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,6 +54,7 @@ bool CheckValidFuncGraphInput(const AnfNodePtr &node) {
|
||||||
|
|
||||||
// Get the funcgraph in partial node.
|
// Get the funcgraph in partial node.
|
||||||
FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) {
|
FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
|
const auto &partial_inputs = node->cast<CNodePtr>()->inputs();
|
||||||
return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
|
return GetValueNode<FuncGraphPtr>(partial_inputs[1]);
|
||||||
}
|
}
|
||||||
|
@ -313,7 +314,7 @@ std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std
|
||||||
if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) {
|
if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) {
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
if (!IsCallNode(real_output.first)) {
|
if (!AnfAlgo::IsCallNode(real_output.first)) {
|
||||||
outputs.push_back(real_output.first);
|
outputs.push_back(real_output.first);
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
@ -349,7 +350,7 @@ std::vector<AnfNodePtr> FetchOutputByCallNode(const AnfNodePtr &call_node, std::
|
||||||
} else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
|
} else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) {
|
||||||
const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
|
const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes);
|
||||||
(void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
|
(void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
|
||||||
} else if (IsCallNode(graph_output)) {
|
} else if (AnfAlgo::IsCallNode(graph_output)) {
|
||||||
const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
|
const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes);
|
||||||
(void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
|
(void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
|
||||||
} else if (graph_output->isa<CNode>()) {
|
} else if (graph_output->isa<CNode>()) {
|
||||||
|
@ -388,7 +389,7 @@ std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, s
|
||||||
} else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
|
} else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
|
||||||
const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes);
|
const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes);
|
||||||
(void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
|
(void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end());
|
||||||
} else if (IsCallNode(inputs[i])) {
|
} else if (AnfAlgo::IsCallNode(inputs[i])) {
|
||||||
const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes);
|
const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes);
|
||||||
(void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
|
(void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end());
|
||||||
} else {
|
} else {
|
||||||
|
@ -486,7 +487,7 @@ FuncGraphPtr FetchFuncGraphInNode(const auto &node) {
|
||||||
|
|
||||||
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) {
|
AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) {
|
||||||
const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
||||||
if (!IsCallNode(real_node)) {
|
if (!AnfAlgo::IsCallNode(real_node)) {
|
||||||
return real_node;
|
return real_node;
|
||||||
}
|
}
|
||||||
if ((*call_nodes).find(real_node) != (*call_nodes).end()) {
|
if ((*call_nodes).find(real_node) != (*call_nodes).end()) {
|
||||||
|
@ -513,15 +514,6 @@ bool HasAbstractRef(const AnfNodePtr &node) {
|
||||||
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsCallNode(const AnfNodePtr &node) {
|
|
||||||
if (!node->isa<CNode>()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
const auto &cnode = node->cast<CNodePtr>();
|
|
||||||
const auto &inputs = cnode->inputs();
|
|
||||||
return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsSubCallNode(const AnfNodePtr &node) {
|
bool IsSubCallNode(const AnfNodePtr &node) {
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -604,7 +596,7 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
|
||||||
func_graphs.emplace_back(func_graph);
|
func_graphs.emplace_back(func_graph);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (IsCallNode(cnode)) {
|
} else if (AnfAlgo::IsCallNode(cnode)) {
|
||||||
return FetchFuncGraphbyCallNode(cnode);
|
return FetchFuncGraphbyCallNode(cnode);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
|
MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString();
|
||||||
|
@ -618,7 +610,7 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes) {
|
size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes) {
|
||||||
if (!IsCallNode(node)) {
|
if (!AnfAlgo::IsCallNode(node)) {
|
||||||
MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node);
|
MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node);
|
||||||
}
|
}
|
||||||
if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) {
|
if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) {
|
||||||
|
@ -631,7 +623,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
|
||||||
const auto &output = func_graph->output();
|
const auto &output = func_graph->output();
|
||||||
const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0);
|
const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0);
|
||||||
|
|
||||||
if (IsCallNode(real_output.first)) {
|
if (AnfAlgo::IsCallNode(real_output.first)) {
|
||||||
size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes);
|
size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes);
|
||||||
if (output_num > 0) {
|
if (output_num > 0) {
|
||||||
return output_num;
|
return output_num;
|
||||||
|
@ -642,7 +634,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr>
|
||||||
const auto &inputs = tuple_cnode->inputs();
|
const auto &inputs = tuple_cnode->inputs();
|
||||||
size_t i = 1;
|
size_t i = 1;
|
||||||
for (; i < inputs.size(); ++i) {
|
for (; i < inputs.size(); ++i) {
|
||||||
if (IsCallNode(inputs[i])) {
|
if (AnfAlgo::IsCallNode(inputs[i])) {
|
||||||
size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes);
|
size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes);
|
||||||
if (call_output_num == 0) {
|
if (call_output_num == 0) {
|
||||||
break;
|
break;
|
||||||
|
@ -872,7 +864,7 @@ void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node
|
||||||
if (node_value->isa<tensor::Tensor>()) {
|
if (node_value->isa<tensor::Tensor>()) {
|
||||||
(void)((*value_nodes).emplace_back(input));
|
(void)((*value_nodes).emplace_back(input));
|
||||||
}
|
}
|
||||||
} else if (IsCallNode(input)) {
|
} else if (AnfAlgo::IsCallNode(input)) {
|
||||||
// If input is a call not, should check the switch node in its input.
|
// If input is a call not, should check the switch node in its input.
|
||||||
const auto &call_node = input->cast<CNodePtr>();
|
const auto &call_node = input->cast<CNodePtr>();
|
||||||
const auto &call_inputs = call_node->inputs();
|
const auto &call_inputs = call_node->inputs();
|
||||||
|
@ -1050,7 +1042,7 @@ void ControlNodeParser::FetchFrontToFrontParameter(
|
||||||
std::vector<AnfNodePtr> call_inputs;
|
std::vector<AnfNodePtr> call_inputs;
|
||||||
call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end());
|
call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end());
|
||||||
switch_input_parse(inputs[0], call_inputs);
|
switch_input_parse(inputs[0], call_inputs);
|
||||||
} else if (IsCallNode(inputs[0])) {
|
} else if (AnfAlgo::IsCallNode(inputs[0])) {
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:"
|
MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:"
|
||||||
|
@ -1098,7 +1090,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std::
|
||||||
|
|
||||||
void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes) {
|
void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes) {
|
||||||
for (const auto &control_node : control_nodes) {
|
for (const auto &control_node : control_nodes) {
|
||||||
if (IsCallNode(control_node)) {
|
if (AnfAlgo::IsCallNode(control_node)) {
|
||||||
const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
|
const auto &func_graphs = FetchFuncGraphbyCallNode(control_node);
|
||||||
|
|
||||||
for (const auto &func_graph : func_graphs) {
|
for (const auto &func_graph : func_graphs) {
|
||||||
|
@ -1123,7 +1115,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphP
|
||||||
const auto inputs = graph->input_nodes();
|
const auto inputs = graph->input_nodes();
|
||||||
for (const auto &input : inputs) {
|
for (const auto &input : inputs) {
|
||||||
const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
|
const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input);
|
||||||
if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) {
|
if (internal_parameter_with_index.first != nullptr && AnfAlgo::IsCallNode(internal_parameter_with_index.first)) {
|
||||||
call_input_kernel_graphs_[graph] = device_context;
|
call_input_kernel_graphs_[graph] = device_context;
|
||||||
call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context};
|
call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context};
|
||||||
}
|
}
|
||||||
|
@ -1162,12 +1154,12 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node,
|
||||||
for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
|
for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) {
|
||||||
if (inputs[i]->isa<Parameter>()) {
|
if (inputs[i]->isa<Parameter>()) {
|
||||||
(void)parameters.emplace_back(inputs[i]);
|
(void)parameters.emplace_back(inputs[i]);
|
||||||
} else if (IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
|
} else if (AnfAlgo::IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) {
|
||||||
const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes);
|
const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes);
|
||||||
(void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
|
(void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (IsCallNode(node)) {
|
} else if (AnfAlgo::IsCallNode(node)) {
|
||||||
if ((*call_nodes).find(node) != (*call_nodes).end()) {
|
if ((*call_nodes).find(node) != (*call_nodes).end()) {
|
||||||
return parameters;
|
return parameters;
|
||||||
}
|
}
|
||||||
|
@ -1296,7 +1288,7 @@ void ControlNodeParser::FetchFuncGraphToParameter(const std::vector<AnfNodePtr>
|
||||||
} else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
|
} else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) {
|
||||||
// Switchlayer node.
|
// Switchlayer node.
|
||||||
FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
|
FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_);
|
||||||
} else if (IsCallNode(inputs[0])) {
|
} else if (AnfAlgo::IsCallNode(inputs[0])) {
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
|
MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString();
|
||||||
|
@ -1373,7 +1365,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_
|
||||||
for (const auto &switch_output : switch_outputs) {
|
for (const auto &switch_output : switch_outputs) {
|
||||||
FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results);
|
FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results);
|
||||||
}
|
}
|
||||||
} else if (IsCallNode(front_output)) {
|
} else if (AnfAlgo::IsCallNode(front_output)) {
|
||||||
// Output is a call.
|
// Output is a call.
|
||||||
const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes);
|
const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes);
|
||||||
|
|
||||||
|
@ -1429,7 +1421,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode(
|
||||||
}
|
}
|
||||||
} else if (real_parameter->isa<ValueNode>()) {
|
} else if (real_parameter->isa<ValueNode>()) {
|
||||||
(void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0);
|
(void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0);
|
||||||
} else if (IsCallNode(real_parameter)) {
|
} else if (AnfAlgo::IsCallNode(real_parameter)) {
|
||||||
const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
|
const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter);
|
||||||
for (const auto func_graph : func_graphs) {
|
for (const auto func_graph : func_graphs) {
|
||||||
FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters);
|
FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters);
|
||||||
|
|
|
@ -56,11 +56,6 @@ using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNode
|
||||||
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
|
using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>;
|
||||||
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
|
using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>;
|
||||||
|
|
||||||
// Check whether node is a call node, there are two types of call nodes:
|
|
||||||
// 1. First input of node is a cnode.
|
|
||||||
// 2. First input of node is a funcgraph value node.
|
|
||||||
bool IsCallNode(const AnfNodePtr &node);
|
|
||||||
|
|
||||||
// Check if the call node is the input of another call node.
|
// Check if the call node is the input of another call node.
|
||||||
bool IsSubCallNode(const AnfNodePtr &node);
|
bool IsSubCallNode(const AnfNodePtr &node);
|
||||||
|
|
||||||
|
|
|
@ -1689,8 +1689,9 @@ void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, con
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_type);
|
MS_EXCEPTION_IF_NULL(kernel_type);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_name);
|
MS_EXCEPTION_IF_NULL(kernel_name);
|
||||||
|
// In sink mode, the data exchange between child graphs is expressed as parameters. These parameters are stored
|
||||||
if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>())) {
|
// in the graph and should be obtained from the super kernel actor.
|
||||||
|
if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>() || graph->IsChildGraphResult(node))) {
|
||||||
*kernel_type = KernelTransformType::kSuperKernelActor;
|
*kernel_type = KernelTransformType::kSuperKernelActor;
|
||||||
*kernel_name = graph->ToString() + "_SuperKernelActor";
|
*kernel_name = graph->ToString() + "_SuperKernelActor";
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -980,13 +980,6 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
|
||||||
AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
|
AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
|
||||||
size_t position = 0;
|
size_t position = 0;
|
||||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
|
auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
|
||||||
if (runtime::IsCallNode(root_output)) {
|
|
||||||
std::vector<AnfNodePtr> call_nodes;
|
|
||||||
size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes);
|
|
||||||
for (size_t i = 0; i < call_output_num; ++i) {
|
|
||||||
(void)outputs.emplace_back(root_output, i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
outputs_num = outputs.size();
|
outputs_num = outputs.size();
|
||||||
for (const auto &output : outputs) {
|
for (const auto &output : outputs) {
|
||||||
if (outputs_order.count(output) == 0) {
|
if (outputs_order.count(output) == 0) {
|
||||||
|
|
Loading…
Reference in New Issue