forked from mindspore-Ecosystem/mindspore
!1186 [control sink refactor]: construct kernel graph
Merge pull request !1186 from wenchunjiang/construct_kernel_graph
This commit is contained in:
commit
4ecc9389e0
|
@ -909,5 +909,15 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
|
|||
auto kernel_name = AnfAlgo::GetCNodeName(node);
|
||||
return kernel_name == kGetNextOpName;
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto func_graph = value->cast<FuncGraphPtr>();
|
||||
return func_graph;
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -181,6 +181,7 @@ class AnfRuntimeAlgorithm {
|
|||
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
|
||||
static bool IsCommunicationOp(const AnfNodePtr &node);
|
||||
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
|
||||
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -462,6 +462,126 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
|
|||
return new_cnode;
|
||||
}
|
||||
|
||||
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> cnode_inputs;
|
||||
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
||||
MS_EXCEPTION_IF_NULL(attr_input);
|
||||
if (IsValueNode<FuncGraph>(attr_input)) {
|
||||
// create primitive of cnode:call
|
||||
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
|
||||
// create a ValueNode<KernelGraph> as input of cnode:call
|
||||
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
|
||||
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
|
||||
} else {
|
||||
auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
|
||||
if (new_value_node != nullptr) {
|
||||
cnode_inputs.emplace_back(new_value_node);
|
||||
}
|
||||
}
|
||||
} else if (attr_input->isa<CNode>()) {
|
||||
// create primitive of cnode:call(switch)
|
||||
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
|
||||
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
|
||||
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
||||
auto prim = GetCNodePrimitive(cnode_input);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() != kSwitchOpName) {
|
||||
MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
|
||||
}
|
||||
cnode_inputs.emplace_back(cnode_input);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
|
||||
<< ", but input[0] has not been created.";
|
||||
}
|
||||
} else {
|
||||
// get primitive of old node
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
// push attr to inputs[0] of new cnode
|
||||
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
|
||||
}
|
||||
|
||||
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
||||
auto anf = cnode->inputs()[input_idx];
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
// anf has been created before
|
||||
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
||||
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
||||
continue;
|
||||
} else if (anf->isa<ValueNode>()) {
|
||||
if (!IsValueNode<FuncGraph>(anf)) {
|
||||
// if input is a common value node,
|
||||
auto new_value_node = CreateNewValueNode(anf, graph);
|
||||
if (new_value_node != nullptr) {
|
||||
cnode_inputs.emplace_back(new_value_node);
|
||||
}
|
||||
} else {
|
||||
// if input is a ValueNode<FuncGraph>
|
||||
auto new_value_node = CreateValueNodeKernelGraph(anf, graph);
|
||||
if (new_value_node != nullptr) {
|
||||
cnode_inputs.emplace_back(new_value_node);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else if (anf->isa<Parameter>()) {
|
||||
auto new_parameter = CreateNewParameter(anf, graph);
|
||||
cnode_inputs.push_back(new_parameter);
|
||||
continue;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
|
||||
}
|
||||
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
|
||||
auto new_cnode = graph->NewCNode(cnode_inputs);
|
||||
TraceManager::EndTrace();
|
||||
return new_cnode;
|
||||
}
|
||||
|
||||
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
auto value_node = anf->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
|
||||
MS_EXCEPTION_IF_NULL(sub_func_graph);
|
||||
if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
|
||||
}
|
||||
auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
|
||||
|
||||
ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
|
||||
new_value_node->set_abstract(value_node->abstract());
|
||||
// create new kernel_info of new value_node
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
kernel_info->SetFeatureMapFlag(false);
|
||||
new_value_node->set_kernel_info(kernel_info);
|
||||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
|
||||
AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
|
||||
|
||||
graph->FrontBackendlMapAdd(anf, new_value_node);
|
||||
graph->AddValueNodeToGraph(new_value_node);
|
||||
|
||||
return new_value_node;
|
||||
}
|
||||
|
||||
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
||||
}
|
||||
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
|
||||
auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
|
||||
graph_inputs->push_back(new_parameter);
|
||||
graph->FrontBackendlMapAdd(anf, new_parameter);
|
||||
|
||||
return new_parameter;
|
||||
}
|
||||
|
||||
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
|
||||
auto graph = NewKernelGraph();
|
||||
|
@ -505,7 +625,69 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|||
return graph;
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &) { return nullptr; }
|
||||
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) {
|
||||
MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph.";
|
||||
return front_backend_graph_map_[func_graph];
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
auto graph = NewKernelGraph();
|
||||
front_backend_graph_map_[func_graph] = graph;
|
||||
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
|
||||
|
||||
for (const auto &node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
||||
if (!node->isa<CNode>()) {
|
||||
MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode";
|
||||
continue;
|
||||
} else {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
// recurse control ops: call, partial
|
||||
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
||||
MS_EXCEPTION_IF_NULL(attr_input);
|
||||
if (IsValueNode<FuncGraph>(attr_input)) {
|
||||
// recurse call subgraph
|
||||
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input);
|
||||
ConstructKernelGraph(sub_func_graph);
|
||||
} else if (IsValueNode<Primitive>(attr_input)) {
|
||||
auto prim = GetCNodePrimitive(node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() == kPartialOpName) {
|
||||
// recurse partial subgraph
|
||||
auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_node);
|
||||
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
|
||||
ConstructKernelGraph(sub_func_graph);
|
||||
}
|
||||
}
|
||||
|
||||
// create a new cnode object
|
||||
auto new_cnode = CreateNewCNode(cnode, graph.get());
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
graph->FrontBackendlMapAdd(node, new_cnode);
|
||||
|
||||
// set original return to kernel_graph
|
||||
if (IsPrimitive(new_cnode->input(kAnfPrimitiveIndex), prim::kPrimReturn)) {
|
||||
graph->set_return(new_cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(context_);
|
||||
FuncGraphManagerPtr manager = context_->manager();
|
||||
if (manager) {
|
||||
manager->AddFuncGraph(graph);
|
||||
graph->set_manager(manager);
|
||||
}
|
||||
graph->SetExecOrderByDefault();
|
||||
return graph;
|
||||
}
|
||||
|
||||
// run graph steps
|
||||
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
|
|
|
@ -78,6 +78,7 @@ class SessionBasic {
|
|||
|
||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||
|
||||
// set parameters of final graph
|
||||
virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
|
||||
|
@ -111,9 +112,12 @@ class SessionBasic {
|
|||
// create a new kernel graph and update the graph sum
|
||||
KernelGraphPtr NewKernelGraph();
|
||||
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
|
||||
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
|
||||
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
|
||||
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
|
||||
std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_;
|
||||
std::shared_ptr<Context> context_;
|
||||
CallBackFunc summary_callback_;
|
||||
static GraphId graph_sum_;
|
||||
|
|
|
@ -141,6 +141,9 @@ constexpr auto kLabelSetOpName = "LabelSet";
|
|||
constexpr auto kLabelSwitchOpName = "LabelSwitch";
|
||||
constexpr auto kLabelGotoOpName = "LabelGoto";
|
||||
constexpr auto kBNInferGradOpName = "BNInferGrad";
|
||||
constexpr auto kCallOpName = "call";
|
||||
constexpr auto kPartialOpName = "partial";
|
||||
constexpr auto kSwitchOpName = "switch";
|
||||
|
||||
// attr key name
|
||||
constexpr auto kAttrInputNames = "input_names";
|
||||
|
@ -197,6 +200,7 @@ const int kValueNodeTensorMask = 2;
|
|||
|
||||
// define special index in special node
|
||||
constexpr auto kAnfPrimitiveIndex = 0;
|
||||
constexpr auto kAnfPartialFuncGraphIndex = 1;
|
||||
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
|
||||
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
|
||||
constexpr auto kTupleGetItemInputSize = 3;
|
||||
|
|
Loading…
Reference in New Issue