split-graph-for-control-sink

This commit is contained in:
chenfei 2020-05-16 20:57:09 +08:00
parent bca2b1a055
commit 7d41812b98
10 changed files with 350 additions and 40 deletions

View File

@ -65,6 +65,7 @@ const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");

View File

@ -71,6 +71,7 @@ extern const PrimitivePtr kPrimAssign;
extern const PrimitivePtr kPrimAssignAdd;
extern const PrimitivePtr kPrimAssignSub;
extern const PrimitivePtr kPrimSelect;
extern const PrimitivePtr kPrimCall;
extern const PrimitivePtr kPrimDistribute;
extern const PrimitivePtr kPrimDot;

View File

@ -271,7 +271,9 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
TypePtr type = node->Type();
MS_EXCEPTION_IF_NULL(type);
if (type == nullptr) {
return 0;
}
if (type->isa<Tuple>()) {
auto tuple_type = type->cast<TuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_type);
@ -913,11 +915,66 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
if (value_node == nullptr) {
return nullptr;
}
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (value == nullptr) {
return nullptr;
}
auto func_graph = value->cast<FuncGraphPtr>();
return func_graph;
}
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
}
MS_EXCEPTION_IF_NULL(call_node);
auto input1 = call_node->input(1);
MS_EXCEPTION_IF_NULL(input1);
if (input1->isa<ValueNode>()) {
auto value_node = input1->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto kernel_graph = value_node->value();
MS_EXCEPTION_IF_NULL(kernel_graph);
return {kernel_graph->cast<KernelGraphPtr>()};
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
auto switch_node = input1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_node);
MS_LOG(INFO) << "switch : " << switch_node->DebugString();
auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
auto partial = switch_node->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto graph_node = partial_cnode->input(1);
MS_EXCEPTION_IF_NULL(graph_node);
MS_LOG(INFO) << graph_node->DebugString();
auto graph_value_node = graph_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_value_node);
auto graph_value = graph_value_node->value();
MS_EXCEPTION_IF_NULL(graph_value);
auto child_graph = graph_value->cast<KernelGraphPtr>();
return child_graph;
};
return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
}
return {};
}
bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
MS_EXCEPTION_IF_NULL(call_node);
if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
}
auto input1 = call_node->input(1);
if (input1->isa<ValueNode>()) {
return false;
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
return true;
}
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
}
} // namespace session
} // namespace mindspore

View File

@ -32,6 +32,7 @@
#include "kernel/kernel_build_info.h"
#include "operator/ops.h"
#include "utils/contract.h"
#include "session/kernel_graph.h"
namespace mindspore {
namespace session {
@ -182,6 +183,8 @@ class AnfRuntimeAlgorithm {
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
static bool IsSwitchCall(const CNodePtr &call_node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@ -156,6 +156,89 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
}
}
}
std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
std::vector<CNodePtr> cnodes = {};
size_t i = 0;
for (const auto anf : anf_nodes) {
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
MS_EXCEPTION_IF_NULL(anf);
if (anf->isa<CNode>()) {
cnodes.push_back(anf->cast<CNodePtr>());
}
}
return std::move(cnodes);
}
std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) {
size_t after_call_index = 0;
std::vector<std::vector<CNodePtr>> ret;
for (size_t i = 0; i < cnodes.size(); i++) {
if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) {
auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]);
// if graph is the true branch of while,no need split graph
if (call_kernel_graph.size() == 1 && call_kernel_graph[0] == cur_graph.parent_graph()) {
continue;
}
auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i);
auto call_list = std::vector<CNodePtr>(1, cnodes[i]);
after_call_index = i + 1;
ret.push_back(prev_call_list);
ret.push_back(call_list);
} else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) {
ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end()));
}
}
return std::move(ret);
}
void UpdateRealInput(KernelGraph *graph) {
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> &parameters,
const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
MS_EXCEPTION_IF_NULL(child_graph);
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
if (args.empty()) {
return;
}
if (parameters.size() != args.size()) {
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
<< " and args size:" << args.size() << " not equal!";
}
for (size_t i = 0; i < parameters.size(); i++) {
MS_LOG(INFO) << "bind paramreter:" << parameters[i]->DebugString() << " ,arg:" << args[i]->DebugString();
child_graph->SetRealInput(parameters[i], args[i]);
}
};
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
if (child_graphs.size() == 1) {
MS_EXCEPTION_IF_NULL(child_graphs[0]);
bind_call_partial_with_parameter(
child_graphs[0]->inputs(), std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()),
child_graphs[0].get());
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
} else if (child_graphs.size() == 2) {
auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
auto switch_node = call_node->input(1);
MS_EXCEPTION_IF_NULL(switch_node);
auto switch_cnode = switch_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
auto partial = switch_cnode->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
partial_cnode->set_inputs(
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
return std::move(ret);
};
bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
}
}
}
} // namespace
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
@ -171,7 +254,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph);
// split switch
SplitSwitch(graph.get());
SplitGraph(graph);
// insert goto labels and label_sets
LinkChildGraphs(graph.get());
// resource initialize
@ -1297,5 +1380,107 @@ void AscendSession::SyncInitialTenosrToDevice() {
}
}
}
KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph,
const std::vector<CNodePtr> &list) {
MS_EXCEPTION_IF_NULL(new_kernel_graph);
MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id();
// count the output of every anf node
std::set<AnfNodePtr> has_output_nodes;
for (auto &anf_node : list) {
for (auto &input : anf_node->inputs()) {
(void)has_output_nodes.insert(input);
}
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
new_kernel_graph->set_return(anf_node->cast<CNodePtr>());
}
}
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
// create new parameter from cnode
for (auto &anf_node : list) {
auto cnode = anf_node->cast<CNodePtr>();
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto input = cnode->inputs()[input_idx];
if (!input->isa<CNode>()) {
cnode->set_input(input_idx, input);
continue;
}
if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
cnode->set_input(input_idx, new_parameter);
new_kernel_graph->SetRealInput(new_parameter, input);
}
}
}
MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
int output_idx = 0;
for (auto &anf_node : list) {
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
new_kernel_graph->set_return(anf_node);
}
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
make_tuple_inputs.push_back(anf_node);
}
}
if (new_kernel_graph->get_return() == nullptr) {
new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
}
MS_LOG(INFO) << "end";
return new_kernel_graph;
}
void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
MS_EXCEPTION_IF_NULL(graph);
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
// update the root graph child graph order
graph->UpdateChildGraphOrder();
// get child list from current graph
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
return child_graph_list[0];
}
// create new child graph
auto child_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(child_graph);
// create new value node to bind child graph
auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
graph_value_node};
// set the graph id of all node of child graph
for (auto &child_graph_node : child_graph_list) {
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
}
SplitKernelGraph(child_graph, child_graph_list);
auto new_call = graph->NewCNode(new_call_input);
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
return new_call;
};
if (child_graph_lists.size() > 1) {
for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
if (call_index == 0) {
auto new_return_primitive =
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
graph->set_return(graph->NewCNode({new_return_primitive, call_node}));
continue;
}
InsertDependToGraph(graph->graph_id(), call_node);
}
}
graph->UpdateChildGraphOrder();
UpdateRealInput(graph.get());
auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id()));
DumpIR(graph_name, graph);
MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
// recurse to split child graph
for (auto &child_graph : graph->child_graph_order()) {
SplitGraph(child_graph);
}
}
} // namespace session
} // namespace mindspore

View File

@ -95,13 +95,16 @@ class AscendSession : public SessionBasic {
void SetFinalGraphOutput(const ValuePtr &value);
void SetFinalGraphOutput(const VectorRef &vec_output);
void SplitSwitch(KernelGraph *graph) {}
void SplitGraph(const KernelGraphPtr &graph);
void LinkChildGraphs(KernelGraph *graph) {}
void IRFusion(const KernelGraphPtr &graph) {}
void SelectKernelGraphKernel(const KernelGraph &graph) {}
void ConvertPredictModel(const KernelGraphPtr graph) {}
void HardwareOptimizeGraphs(const KernelGraphPtr graph) {}
void RootGraphExecutorValidate(KernelGraph *graph) {}
void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph);
KernelGraphPtr SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list);
void ChildGraphCommunicationDecrease(std::vector<std::vector<AnfNodePtr>> *anf_node_lists);
// merge execution order list of child graphs
void MergeGraphExecOrder();

View File

@ -16,9 +16,8 @@
#include "session/kernel_graph.h"
#include <algorithm>
#include <queue>
#include <stack>
#include <unordered_set>
#include "common/utils.h"
#include <set>
#include "operator/ops.h"
#include "ir/param_value_py.h"
#include "session/anf_runtime_algorithm.h"
@ -311,9 +310,10 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
auto output_tensor_num = AnfAlgo::GetOutputTensorNum(value_node);
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
// set value node initial device data type = infer data type
std::vector<TypeId> types = std::vector<TypeId>(AnfAlgo::GetOutputTensorNum(value_node), kTypeUnknown);
std::vector<TypeId> types = std::vector<TypeId>(output_tensor_num, kTypeUnknown);
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
@ -584,7 +584,25 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
}
}
void KernelGraph::UpdateChildGraphOrder() {}
void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "graph id:" << graph_id_;
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
child_graph_order_.clear();
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto call_child_graphs = AnfAlgo ::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
for (const auto &child_graph : call_child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph != parent_graph()) {
child_graph->set_parent_graph(shared_from_this()->cast<std::shared_ptr<KernelGraph>>());
child_graph_order_.push_back(child_graph);
}
}
}
for (size_t i = 0; i < child_graph_order_.size(); i++) {
MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order_[i]->graph_id() << "]";
}
}
std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
@ -601,5 +619,36 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
}
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
auto anf_list = TopoSort(get_return());
std::vector<CNodePtr> result;
for (const auto &anf : anf_list) {
if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
result.push_back(anf->cast<CNodePtr>());
}
}
return result;
}
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter);
if (real_inputs_.find(parameter) == real_inputs_.end()) {
return {};
}
return real_inputs_[parameter];
}
void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg);
if (real_inputs_.find(parameter) == real_inputs_.end()) {
real_inputs_[parameter] = std::set<AnfNodePtr>();
}
auto &args = real_inputs_[parameter];
(void)args.insert(arg);
}
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
} // namespace session
} // namespace mindspore

View File

@ -23,6 +23,7 @@
#include <string>
#include <queue>
#include <map>
#include <set>
#include <unordered_set>
#include "ir/func_graph.h"
#include "ir/anf.h"
@ -113,6 +114,17 @@ class KernelGraph : public FuncGraph {
}
// get input_tensors pointer of control parameter
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; }
// get parent kernel graph
std::shared_ptr<KernelGraph> parent_graph() const { return parent_graph_; }
// set parent kernel graph
void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
// find anf node in graph
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
// get real inputs
std::set<AnfNodePtr> GetRealInput(const AnfNodePtr &parameter);
void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
// used to dump ir
std::string ToString() const override;
private:
// remove value node form graph
@ -158,6 +170,10 @@ class KernelGraph : public FuncGraph {
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
// input_tensors of control parameter
std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
// parameter graph
std::shared_ptr<KernelGraph> parent_graph_;
// record real parameters,inputs_ is the formal parameters
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_;
};
} // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;

View File

@ -247,27 +247,6 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool va
return parameters;
}
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode";
}
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
if (parameters.empty()) {
MS_LOG(EXCEPTION) << "No parameter exist!!";
}
if (parameters.size() == 1) {
return parameters[0];
}
std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
auto make_tuple = graph->NewCNode(make_tuple_input);
MS_EXCEPTION_IF_NULL(make_tuple);
MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
return make_tuple;
}
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
MS_LOG(INFO) << "Load kInputCtrlTensors";
auto inputs_params = graph->input_ctrl_tensors();
@ -390,6 +369,24 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
return new_parameter;
}
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
if (parameters.empty()) {
MS_LOG(EXCEPTION) << "No parameter exist!!";
}
if (parameters.size() == 1) {
return parameters[0];
}
std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
auto make_tuple = graph->NewCNode(make_tuple_input);
MS_EXCEPTION_IF_NULL(make_tuple);
MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
return make_tuple;
}
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
@ -454,7 +451,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
MS_EXCEPTION_IF_NULL(attr_input);
if (IsValueNode<FuncGraph>(attr_input)) {
// create primitive of cnode:call
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
// create a ValueNode<KernelGraph> as input of cnode:call
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
@ -466,12 +463,10 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
}
} else if (attr_input->isa<CNode>()) {
// create primitive of cnode:call(switch)
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
auto prim = GetCNodePrimitive(cnode_input);
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() != kSwitchOpName) {
if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
}
cnode_inputs.emplace_back(cnode_input);
@ -484,7 +479,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
// push attr to inputs[0] of new cnode
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
}
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
@ -545,7 +540,6 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
graph->FrontBackendlMapAdd(anf, new_value_node);
graph->AddValueNodeToGraph(new_value_node);
return new_value_node;
}
@ -555,11 +549,11 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
TraceManager::EndTrace();
graph_inputs->push_back(new_parameter);
graph->FrontBackendlMapAdd(anf, new_parameter);

View File

@ -114,6 +114,7 @@ class SessionBasic {
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;