forked from mindspore-Ecosystem/mindspore
split-graph-for-control-sink
This commit is contained in:
parent
bca2b1a055
commit
7d41812b98
|
@ -65,6 +65,7 @@ const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
|
|||
const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
|
||||
const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
|
||||
const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
|
||||
const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
|
||||
|
||||
const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
|
||||
const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
|
||||
|
|
|
@ -71,6 +71,7 @@ extern const PrimitivePtr kPrimAssign;
|
|||
extern const PrimitivePtr kPrimAssignAdd;
|
||||
extern const PrimitivePtr kPrimAssignSub;
|
||||
extern const PrimitivePtr kPrimSelect;
|
||||
extern const PrimitivePtr kPrimCall;
|
||||
|
||||
extern const PrimitivePtr kPrimDistribute;
|
||||
extern const PrimitivePtr kPrimDot;
|
||||
|
|
|
@ -271,7 +271,9 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
|
|||
size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
TypePtr type = node->Type();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
if (type->isa<Tuple>()) {
|
||||
auto tuple_type = type->cast<TuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_type);
|
||||
|
@ -913,11 +915,66 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
|
|||
// Extract the FuncGraph held by `node`.
// Returns nullptr when `node` is not a ValueNode or the ValueNode holds no value; the final cast is
// returned as-is (presumably nullptr when the value is not a FuncGraph).
// Raises through MS_EXCEPTION_IF_NULL only when `node` itself is null.
FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // A failed cast is an expected outcome here, not an error: report "no graph" instead of raising,
  // otherwise the nullptr branches below could never be reached.
  auto value_node = node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    return nullptr;
  }
  auto value = value_node->value();
  if (value == nullptr) {
    return nullptr;
  }
  return value->cast<FuncGraphPtr>();
}
|
||||
|
||||
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
|
||||
MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
auto input1 = call_node->input(1);
|
||||
MS_EXCEPTION_IF_NULL(input1);
|
||||
if (input1->isa<ValueNode>()) {
|
||||
auto value_node = input1->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto kernel_graph = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
return {kernel_graph->cast<KernelGraphPtr>()};
|
||||
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
|
||||
auto switch_node = input1->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_node);
|
||||
MS_LOG(INFO) << "switch : " << switch_node->DebugString();
|
||||
auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
|
||||
auto partial = switch_node->input(input_index);
|
||||
MS_EXCEPTION_IF_NULL(partial);
|
||||
auto partial_cnode = partial->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_cnode);
|
||||
auto graph_node = partial_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(graph_node);
|
||||
MS_LOG(INFO) << graph_node->DebugString();
|
||||
auto graph_value_node = graph_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(graph_value_node);
|
||||
auto graph_value = graph_value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(graph_value);
|
||||
auto child_graph = graph_value->cast<KernelGraphPtr>();
|
||||
return child_graph;
|
||||
};
|
||||
return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
|
||||
MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
|
||||
}
|
||||
auto input1 = call_node->input(1);
|
||||
if (input1->isa<ValueNode>()) {
|
||||
return false;
|
||||
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
|
||||
return true;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "kernel/kernel_build_info.h"
|
||||
#include "operator/ops.h"
|
||||
#include "utils/contract.h"
|
||||
#include "session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
@ -182,6 +183,8 @@ class AnfRuntimeAlgorithm {
|
|||
static bool IsCommunicationOp(const AnfNodePtr &node);
|
||||
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
|
||||
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
|
||||
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
|
||||
static bool IsSwitchCall(const CNodePtr &call_node);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -156,6 +156,89 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
|
||||
std::vector<CNodePtr> cnodes = {};
|
||||
size_t i = 0;
|
||||
for (const auto anf : anf_nodes) {
|
||||
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (anf->isa<CNode>()) {
|
||||
cnodes.push_back(anf->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
return std::move(cnodes);
|
||||
}
|
||||
|
||||
std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) {
|
||||
size_t after_call_index = 0;
|
||||
std::vector<std::vector<CNodePtr>> ret;
|
||||
for (size_t i = 0; i < cnodes.size(); i++) {
|
||||
if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) {
|
||||
auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]);
|
||||
// if graph is the true branch of while,no need split graph
|
||||
if (call_kernel_graph.size() == 1 && call_kernel_graph[0] == cur_graph.parent_graph()) {
|
||||
continue;
|
||||
}
|
||||
auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i);
|
||||
auto call_list = std::vector<CNodePtr>(1, cnodes[i]);
|
||||
after_call_index = i + 1;
|
||||
ret.push_back(prev_call_list);
|
||||
ret.push_back(call_list);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
|
||||
ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end()));
|
||||
}
|
||||
}
|
||||
return std::move(ret);
|
||||
}
|
||||
|
||||
// For every `call` node of `graph`, record the call's actual arguments as the "real inputs" of the
// callee graph's formal parameters, then strip those arguments from the node that carried them:
//   - direct call: arguments are the call's inputs from index 2 on; the call keeps inputs [0, 2)
//   - switch call: arguments are taken from each partial (switch inputs 2 and 3), inputs from index 2 on;
//     each partial keeps inputs [0, 2)
// Raises when `graph`, a call node, a callee graph, or a bound parameter/argument is null, or when the
// number of arguments is non-zero and differs from the callee's parameter count.
void UpdateRealInput(KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
  // Pair parameters[i] with args[i] on child_graph; an empty argument list is accepted and binds nothing.
  auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> &parameters,
                                              const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
    MS_EXCEPTION_IF_NULL(child_graph);
    MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
    if (args.empty()) {
      return;
    }
    if (parameters.size() != args.size()) {
      MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
                        << " and args size:" << args.size() << " not equal!";
    }
    for (size_t i = 0; i < parameters.size(); i++) {
      // Both nodes are dereferenced by the log line, so validate them before logging.
      MS_EXCEPTION_IF_NULL(parameters[i]);
      MS_EXCEPTION_IF_NULL(args[i]);
      MS_LOG(INFO) << "bind paramreter:" << parameters[i]->DebugString() << " ,arg:" << args[i]->DebugString();
      child_graph->SetRealInput(parameters[i], args[i]);
    }
  };
  for (auto &call_node : call_nodes) {
    MS_EXCEPTION_IF_NULL(call_node);
    auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
    if (child_graphs.size() == 1) {
      MS_EXCEPTION_IF_NULL(child_graphs[0]);
      bind_call_partial_with_parameter(
        child_graphs[0]->inputs(), std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()),
        child_graphs[0].get());
      call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
    } else if (child_graphs.size() == 2) {
      // Validate both callees before any partial is modified, so a failure leaves the graph untouched.
      MS_EXCEPTION_IF_NULL(child_graphs[0]);
      MS_EXCEPTION_IF_NULL(child_graphs[1]);
      // Detach and return the arguments of the partial found at switch input `input_index`.
      auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
        auto switch_node = call_node->input(1);
        MS_EXCEPTION_IF_NULL(switch_node);
        auto switch_cnode = switch_node->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(switch_cnode);
        auto partial = switch_cnode->input(input_index);
        MS_EXCEPTION_IF_NULL(partial);
        auto partial_cnode = partial->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(partial_cnode);
        auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
        partial_cnode->set_inputs(
          std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
        return ret;
      };
      bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
      bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
    }
  }
}
|
||||
} // namespace
|
||||
|
||||
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
|
@ -171,7 +254,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
MS_LOG(INFO) << "start";
|
||||
auto graph = ConstructKernelGraph(func_graph);
|
||||
// split switch
|
||||
SplitSwitch(graph.get());
|
||||
SplitGraph(graph);
|
||||
// insert goto labels and label_sets
|
||||
LinkChildGraphs(graph.get());
|
||||
// resource initialize
|
||||
|
@ -1297,5 +1380,107 @@ void AscendSession::SyncInitialTenosrToDevice() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Turn the node segment `list` into the body of `new_kernel_graph` and return that graph.
//   1. A `return` node found in `list` becomes the graph's return node.
//   2. Every CNode input (index >= 1) whose graph id differs from the new graph's id is replaced by a
//      parameter created on the new graph; the original node is recorded as that parameter's real input.
//   3. If the segment has no `return`, the graph output is a MakeTuple of all segment nodes that are not
//      consumed by another node of the segment.
// Raises when `new_kernel_graph`, a node of `list`, or one of the inspected inputs is null.
KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph,
                                               const std::vector<CNodePtr> &list) {
  MS_EXCEPTION_IF_NULL(new_kernel_graph);
  MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id();
  // Nodes used as an input by some node of the segment, i.e. nodes whose output is consumed inside it.
  std::set<AnfNodePtr> has_output_nodes;
  for (const auto &cnode : list) {
    MS_EXCEPTION_IF_NULL(cnode);
    for (const auto &input : cnode->inputs()) {
      (void)has_output_nodes.insert(input);
    }
    if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
      new_kernel_graph->set_return(cnode);
    }
  }
  MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
  // create new parameter from cnode
  for (const auto &cnode : list) {
    for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
      // Keep an owning copy: the slot may be overwritten below while the old node is still needed.
      auto input = cnode->inputs()[input_idx];
      MS_EXCEPTION_IF_NULL(input);
      if (!input->isa<CNode>()) {
        // Non-CNode inputs stay in place.
        continue;
      }
      if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
        auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
        cnode->set_input(input_idx, new_parameter);
        new_kernel_graph->SetRealInput(new_parameter, input);
      }
    }
  }
  MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
  auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
  std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
  size_t output_idx = 0;
  for (const auto &cnode : list) {
    if (has_output_nodes.find(cnode) == has_output_nodes.end()) {
      MS_LOG(INFO) << "output[" << output_idx++ << "]:" << cnode->DebugString();
      make_tuple_inputs.push_back(cnode);
    }
  }
  // A segment that carries the original `return` keeps it; otherwise expose the unconsumed nodes.
  if (new_kernel_graph->get_return() == nullptr) {
    new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
  }
  MS_LOG(INFO) << "end";
  return new_kernel_graph;
}
|
||||
|
||||
void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
|
||||
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
|
||||
// update the root graph child graph order
|
||||
graph->UpdateChildGraphOrder();
|
||||
// get child list from current graph
|
||||
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
|
||||
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
|
||||
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
|
||||
return child_graph_list[0];
|
||||
}
|
||||
// create new child graph
|
||||
auto child_graph = NewKernelGraph();
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
// create new value node to bind child graph
|
||||
auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
|
||||
std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
|
||||
graph_value_node};
|
||||
// set the graph id of all node of child graph
|
||||
for (auto &child_graph_node : child_graph_list) {
|
||||
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
|
||||
}
|
||||
SplitKernelGraph(child_graph, child_graph_list);
|
||||
auto new_call = graph->NewCNode(new_call_input);
|
||||
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
|
||||
return new_call;
|
||||
};
|
||||
if (child_graph_lists.size() > 1) {
|
||||
for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
|
||||
auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
|
||||
if (call_index == 0) {
|
||||
auto new_return_primitive =
|
||||
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
|
||||
graph->set_return(graph->NewCNode({new_return_primitive, call_node}));
|
||||
continue;
|
||||
}
|
||||
InsertDependToGraph(graph->graph_id(), call_node);
|
||||
}
|
||||
}
|
||||
graph->UpdateChildGraphOrder();
|
||||
UpdateRealInput(graph.get());
|
||||
auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id()));
|
||||
DumpIR(graph_name, graph);
|
||||
MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
|
||||
// recurse to split child graph
|
||||
for (auto &child_graph : graph->child_graph_order()) {
|
||||
SplitGraph(child_graph);
|
||||
}
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -95,13 +95,16 @@ class AscendSession : public SessionBasic {
|
|||
void SetFinalGraphOutput(const ValuePtr &value);
|
||||
void SetFinalGraphOutput(const VectorRef &vec_output);
|
||||
|
||||
void SplitSwitch(KernelGraph *graph) {}
|
||||
void SplitGraph(const KernelGraphPtr &graph);
|
||||
void LinkChildGraphs(KernelGraph *graph) {}
|
||||
void IRFusion(const KernelGraphPtr &graph) {}
|
||||
void SelectKernelGraphKernel(const KernelGraph &graph) {}
|
||||
void ConvertPredictModel(const KernelGraphPtr graph) {}
|
||||
void HardwareOptimizeGraphs(const KernelGraphPtr graph) {}
|
||||
void RootGraphExecutorValidate(KernelGraph *graph) {}
|
||||
void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph);
|
||||
KernelGraphPtr SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list);
|
||||
void ChildGraphCommunicationDecrease(std::vector<std::vector<AnfNodePtr>> *anf_node_lists);
|
||||
|
||||
// merge execution order list of child graphs
|
||||
void MergeGraphExecOrder();
|
||||
|
|
|
@ -16,9 +16,8 @@
|
|||
#include "session/kernel_graph.h"
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <stack>
|
||||
#include <unordered_set>
|
||||
#include "common/utils.h"
|
||||
#include <set>
|
||||
#include "operator/ops.h"
|
||||
#include "ir/param_value_py.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
@ -311,9 +310,10 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
|
|||
// create kernel_build_info for new value node
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// set the format of value_node to DEFAULT_FORMAT
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
auto output_tensor_num = AnfAlgo::GetOutputTensorNum(value_node);
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
|
||||
// set value node initial device data type = infer data type
|
||||
std::vector<TypeId> types = std::vector<TypeId>(AnfAlgo::GetOutputTensorNum(value_node), kTypeUnknown);
|
||||
std::vector<TypeId> types = std::vector<TypeId>(output_tensor_num, kTypeUnknown);
|
||||
kernel_build_info_builder->SetOutputsDeviceType(types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
|
||||
AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
|
||||
|
@ -584,7 +584,25 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateChildGraphOrder() {}
|
||||
void KernelGraph::UpdateChildGraphOrder() {
|
||||
MS_LOG(INFO) << "graph id:" << graph_id_;
|
||||
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
|
||||
child_graph_order_.clear();
|
||||
for (auto &call_node : call_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
auto call_child_graphs = AnfAlgo ::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
|
||||
for (const auto &child_graph : call_child_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
if (child_graph != parent_graph()) {
|
||||
child_graph->set_parent_graph(shared_from_this()->cast<std::shared_ptr<KernelGraph>>());
|
||||
child_graph_order_.push_back(child_graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < child_graph_order_.size(); i++) {
|
||||
MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order_[i]->graph_id() << "]";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
|
||||
std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
|
||||
|
@ -601,5 +619,36 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
|
|||
}
|
||||
|
||||
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
|
||||
|
||||
std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
|
||||
auto anf_list = TopoSort(get_return());
|
||||
std::vector<CNodePtr> result;
|
||||
for (const auto &anf : anf_list) {
|
||||
if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
|
||||
result.push_back(anf->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
if (real_inputs_.find(parameter) == real_inputs_.end()) {
|
||||
return {};
|
||||
}
|
||||
return real_inputs_[parameter];
|
||||
}
|
||||
|
||||
void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (real_inputs_.find(parameter) == real_inputs_.end()) {
|
||||
real_inputs_[parameter] = std::set<AnfNodePtr>();
|
||||
}
|
||||
auto &args = real_inputs_[parameter];
|
||||
(void)args.insert(arg);
|
||||
}
|
||||
|
||||
// Name of this graph for IR dumps: the prefix "kernel_graph_" followed by the decimal graph id.
std::string KernelGraph::ToString() const {
  std::string label = "kernel_graph_";
  label += std::to_string(graph_id_);
  return label;
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <string>
|
||||
#include <queue>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -113,6 +114,17 @@ class KernelGraph : public FuncGraph {
|
|||
}
|
||||
// get input_tensors pointer of control parameter
|
||||
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; }
|
||||
// get parent kernel graph
|
||||
std::shared_ptr<KernelGraph> parent_graph() const { return parent_graph_; }
|
||||
// set parent kernel graph
|
||||
void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
|
||||
// find anf node in graph
|
||||
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
|
||||
// get real inputs
|
||||
// Returns the set of real arguments recorded for `parameter` (empty when none has been recorded).
std::set<AnfNodePtr> GetRealInput(const AnfNodePtr &parameter);
|
||||
// Records `arg` as a real argument of the formal parameter `parameter`.
void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
|
||||
// used to dump ir
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
// remove value node from graph
|
||||
|
@ -158,6 +170,10 @@ class KernelGraph : public FuncGraph {
|
|||
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
|
||||
// input_tensors of control parameter
|
||||
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
|
||||
// parent kernel graph
|
||||
std::shared_ptr<KernelGraph> parent_graph_;
|
||||
// maps each formal parameter to its recorded real arguments; inputs_ holds the formal parameters
|
||||
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_;
|
||||
};
|
||||
} // namespace session
|
||||
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
||||
|
|
|
@ -247,27 +247,6 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool va
|
|||
return parameters;
|
||||
}
|
||||
|
||||
// Build the parameter(s) on `graph` that stand in for the output of CNode `anf`.
// A single resulting parameter is returned directly; several parameters are wrapped in a new
// MakeTuple CNode created on `graph`, and that node is returned.
// Raises when `anf` is null or not a CNode, when no parameter could be created, or when the
// MakeTuple node could not be created.
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode";
  }
  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  auto created_params = CreateParameterFromTuple(anf, valid_input, graph);
  if (created_params.empty()) {
    MS_LOG(EXCEPTION) << "No parameter exist!!";
  }
  if (created_params.size() == 1) {
    return created_params[0];
  }
  // More than one parameter: gather them behind one MakeTuple so the caller still gets a single node.
  std::vector<AnfNodePtr> tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  tuple_inputs.insert(tuple_inputs.end(), created_params.begin(), created_params.end());
  auto tuple_node = graph->NewCNode(tuple_inputs);
  MS_EXCEPTION_IF_NULL(tuple_node);
  MS_LOG(INFO) << "New make tuple [" << tuple_node->DebugString() << "] of parameters";
  return tuple_node;
}
|
||||
|
||||
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
|
||||
MS_LOG(INFO) << "Load kInputCtrlTensors";
|
||||
auto inputs_params = graph->input_ctrl_tensors();
|
||||
|
@ -390,6 +369,24 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
|
|||
return new_parameter;
|
||||
}
|
||||
|
||||
// Create the parameter(s) on `graph` that represent the output of `anf`.
// Returns the parameter itself when exactly one is created; otherwise returns a new MakeTuple CNode,
// created on `graph`, whose inputs are all created parameters.
// Raises when `anf` or `graph` is null, when no parameter could be created, or when the MakeTuple
// node could not be created.
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  // graph is dereferenced below (graph->NewCNode), so reject a null graph up front.
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  if (parameters.empty()) {
    MS_LOG(EXCEPTION) << "No parameter exist!!";
  }
  if (parameters.size() == 1) {
    return parameters[0];
  }
  // Several parameters: bundle them behind a single MakeTuple node.
  std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  auto make_tuple = graph->NewCNode(make_tuple_input);
  MS_EXCEPTION_IF_NULL(make_tuple);
  MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  return make_tuple;
}
|
||||
|
||||
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
|
||||
bool *from_other_graph,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
||||
|
@ -454,7 +451,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|||
MS_EXCEPTION_IF_NULL(attr_input);
|
||||
if (IsValueNode<FuncGraph>(attr_input)) {
|
||||
// create primitive of cnode:call
|
||||
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
|
||||
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
||||
// create a ValueNode<KernelGraph> as input of cnode:call
|
||||
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
|
||||
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
|
||||
|
@ -466,12 +463,10 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|||
}
|
||||
} else if (attr_input->isa<CNode>()) {
|
||||
// create primitive of cnode:call(switch)
|
||||
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
|
||||
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
||||
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
|
||||
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
||||
auto prim = GetCNodePrimitive(cnode_input);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() != kSwitchOpName) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
|
||||
MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
|
||||
}
|
||||
cnode_inputs.emplace_back(cnode_input);
|
||||
|
@ -484,7 +479,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|||
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
// push attr to inputs[0] of new cnode
|
||||
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
|
||||
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
|
||||
}
|
||||
|
||||
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
||||
|
@ -545,7 +540,6 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
|
|||
AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
|
||||
|
||||
graph->FrontBackendlMapAdd(anf, new_value_node);
|
||||
graph->AddValueNodeToGraph(new_value_node);
|
||||
|
||||
return new_value_node;
|
||||
}
|
||||
|
@ -555,11 +549,11 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
|
|||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
||||
}
|
||||
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
|
||||
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
|
||||
auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
|
||||
TraceManager::EndTrace();
|
||||
graph_inputs->push_back(new_parameter);
|
||||
graph->FrontBackendlMapAdd(anf, new_parameter);
|
||||
|
||||
|
|
|
@ -114,6 +114,7 @@ class SessionBasic {
|
|||
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
|
||||
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
|
||||
|
||||
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
|
||||
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
|
||||
|
|
Loading…
Reference in New Issue