review code second

This commit is contained in:
Margaret_wangrui 2020-06-24 18:00:47 +08:00
parent 927278be44
commit 3ea75e0d98
12 changed files with 136 additions and 95 deletions

View File

@ -162,7 +162,7 @@ void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &inpu
UpdateRefNodeOutputMem(graph);
}
void KernelRuntime::RunOpClearMemory(session::KernelGraph *graph) {
void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
// clear input parameter memory resource
for (const auto &input_node : graph->inputs()) {

View File

@ -54,7 +54,7 @@ class KernelRuntime {
virtual bool Init() = 0;
virtual void AssignMemory(session::KernelGraph *graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
void RunOpClearMemory(session::KernelGraph *graph);
void RunOpClearMemory(const session::KernelGraph *graph);
virtual bool Run(session::KernelGraph *graph);
virtual bool DumpData(session::KernelGraph *graph);
virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger);

View File

@ -303,7 +303,7 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(cnode);
size_t input_num = cnode->inputs().size();
if (input_num == 0) {
MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero";
}
// exclude intputs[0],which is value_node storing attr,inputs left are real input
return input_num - 1;
@ -994,10 +994,10 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node)
}
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
}
MS_EXCEPTION_IF_NULL(call_node);
if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node.";
}
auto input1 = call_node->input(1);
MS_EXCEPTION_IF_NULL(input1);
if (input1->isa<ValueNode>()) {
@ -1009,7 +1009,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
auto switch_node = input1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_node);
auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr {
auto partial = switch_node->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
auto partial_cnode = partial->cast<CNodePtr>();
@ -1031,7 +1031,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
MS_EXCEPTION_IF_NULL(call_node);
if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
MS_LOG(EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString();
}
auto input1 = call_node->input(1);
if (input1->isa<ValueNode>()) {

View File

@ -14,9 +14,9 @@
* limitations under the License.
*/
#include "session/ascend_control_parser.h"
#include <utility>
#include <memory>
#include "session/ascend_control_parser.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/union_find_set.h"
@ -96,7 +96,7 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
const std::set<AnfNodePtr> &parameter_reuse_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (parameter_reuse_set.empty()) {
MS_LOG(EXCEPTION) << "parameter_reuse_set is empty.";
MS_LOG(EXCEPTION) << "Parameter_reuse_set is empty.";
}
if (memo->find(kg.get()) != memo->end()) {
return;
@ -155,6 +155,7 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
std::map<uint32_t, KernelGraphPtr> graph_id_map;
for (auto &g : memo) {
MS_EXCEPTION_IF_NULL(g);
if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id()
<< ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString();
@ -206,6 +207,20 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
}
}
NotNull<CNodePtr> AscendControlParser::GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label) {
CNodePtr start_label;
if (last_node != nullptr && last_label != nullptr) {
start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
kg->set_start_label(start_label);
} else {
// no goto node will jump to start label of root graph, so return a fake label
start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr));
}
return NOT_NULL(start_label);
}
NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label,
const NotNull<std::set<KernelGraphPtr> *> memo) {
@ -225,28 +240,22 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
kg->SetExecOrderByDefault();
const std::vector<CNodePtr> &nodes = kg->execution_order();
// 4. insert first_label
CNodePtr start_label;
if (last_node != nullptr && last_label != nullptr) {
start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
kg->set_start_label(start_label);
} else {
// no goto node will jump to start label of root graph, so return a fake label
start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr));
}
CNodePtr start_label = GetStartLabel(kg, last_node, last_label);
// 5. traverse
for (size_t i = 0; i < nodes.size(); ++i) {
auto &cnode = nodes[i];
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() < kCNodePrim + 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
MS_LOG(DEBUG) << "Continue node " << cnode->DebugString();
continue;
}
AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
MS_EXCEPTION_IF_NULL(arg);
if (IsValueNode<KernelGraph>(arg)) {
RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else if (!arg->isa<CNode>()) {
@ -293,6 +302,7 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
if (from_graph_call_node != nullptr && last_label != nullptr) {
auto label_goto =
kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
MS_EXCEPTION_IF_NULL(label_goto);
MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString();
kg->set_end_goto(label_goto);
}
@ -342,6 +352,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
}
// 1 return label
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
MS_EXCEPTION_IF_NULL(back_label);
MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node "
<< cur_node->DebugString();
// 2 add depend relationship
@ -351,6 +362,9 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
}
// 3 recurse sub graph
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
if (kCNodeSwitchCond >= origin_switch_inputs.size()) {
MS_LOG(EXCEPTION) << "The size of origin_switch_inputs is not more than " << kCNodeSwitchCond;
}
std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]};
@ -538,6 +552,8 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
auto start_label_set = child_graph->get_start_label();
uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex);
if (label_index != start_label_set_index) {
MS_EXCEPTION_IF_NULL(cur_label);
MS_EXCEPTION_IF_NULL(start_label_set);
MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString()
<< " index " << start_label_set_index << " current child graph order : " << order_index;
return false;
@ -563,7 +579,7 @@ void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
}
}
for (size_t i = 0; i < child_graph_order.size(); i++) {
MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
}
kg->set_child_graph_order(child_graph_order);
}

View File

@ -38,6 +38,8 @@ class AscendControlParser {
static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg);
private:
static NotNull<CNodePtr> GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label);
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label,
const NotNull<std::set<KernelGraphPtr> *> memo);

View File

@ -139,16 +139,16 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar
}
}
if (graph_inputs.size() != valid_inputs.size()) {
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
MS_LOG(EXCEPTION) << "Graph_inputs.size(): " << graph_inputs.size()
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
}
if (real_args_size != graph_inputs.size()) {
for (size_t j = 0; j < valid_inputs.size(); j++) {
if (valid_inputs[j]) {
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
MS_LOG(INFO) << "Index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
}
}
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
MS_LOG(WARNING) << "Real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
<< " not equal";
}
return real_args;
@ -158,7 +158,7 @@ std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
std::vector<CNodePtr> cnodes = {};
size_t i = 0;
for (const auto &anf : anf_nodes) {
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
MS_LOG(INFO) << "Apply_list[" << i++ << "] = " << anf->DebugString();
MS_EXCEPTION_IF_NULL(anf);
if (anf->isa<CNode>()) {
cnodes.push_back(anf->cast<CNodePtr>());
@ -276,7 +276,7 @@ static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
} // namespace
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
MS_LOG(INFO) << "start";
MS_LOG(INFO) << "Start";
// construct graph, if successfully, graph_sum_ + 1
auto graph = ConstructKernelGraph(lst, outputs);
auto graph_id = graph->graph_id();
@ -285,7 +285,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
}
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start";
MS_LOG(INFO) << "Start";
std::vector<KernelGraphPtr> all_graphs;
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
BackendOptimization(all_graphs);
@ -338,7 +338,7 @@ void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph>
}
void AscendSession::BuildGraph(GraphId graph_id) {
MS_LOG(INFO) << "start";
MS_LOG(INFO) << "Start";
auto graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
// resource initialize
@ -369,6 +369,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
MergeGraphExecOrder();
} else {
auto single_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(single_graph);
CompileChildGraph(single_graph);
// set the distinction label of single graph
single_graph->set_stream_distinction_label(graph_id);
@ -397,7 +398,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
// sync the inital const tensor to device
SyncInitialTenosrToDevice();
ExportChildGraphs(graph_id);
MS_LOG(INFO) << "end";
MS_LOG(INFO) << "End";
}
void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
@ -439,7 +440,7 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *const outputs) {
MS_LOG(INFO) << "start";
MS_LOG(INFO) << "Start";
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
// if none of child graph and no anf output exists
@ -499,7 +500,7 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
MS_EXCEPTION_IF_NULL(runtime_instance);
bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
if (!ret_ok) {
MS_LOG(EXCEPTION) << "run task error!";
MS_LOG(EXCEPTION) << "Run task error!";
}
MS_LOG(INFO) << "Finish!";
}
@ -696,7 +697,7 @@ void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input
MS_LOG(INFO) << "Finish!";
}
void AscendSession::RunOpMemoryClear(KernelGraph *kernel_graph) const {
void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
@ -796,6 +797,7 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
MS_LOG(INFO) << "Start! Args size " << args.size();
auto final_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(final_graph);
final_graph_id_ = final_graph->graph_id();
MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success";
// init private variables and bind them with final_graph_id
@ -826,7 +828,7 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
final_graph_inputs->push_back(parameter_backend);
}
MS_EXCEPTION_IF_NULL(parameter_backend);
MS_LOG(INFO) << "parameter backend " << parameter_backend->DebugString() << " belong_graph_id "
MS_LOG(INFO) << "Parameter backend " << parameter_backend->DebugString() << " belong_graph_id "
<< AnfAlgo::GetGraphId(parameter_backend.get());
}
MS_LOG(INFO) << "End final_graph_id " << final_graph_id_;
@ -842,7 +844,7 @@ void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph,
if (graph_order_iter == graph_execute_orders_.end()) {
SessionBasic::GetSummaryNodes(graph);
auto summary_nodes = graph->summary_nodes();
(*summary).insert(summary_nodes.begin(), summary_nodes.end());
summary->insert(summary_nodes.begin(), summary_nodes.end());
return;
}
// for every child graph, find summary nodes
@ -854,7 +856,7 @@ void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph,
}
SessionBasic::GetSummaryNodes(child_graph.get());
auto child_graph_summary = child_graph->summary_nodes();
(*summary).insert(child_graph_summary.begin(), child_graph_summary.end());
summary->insert(child_graph_summary.begin(), child_graph_summary.end());
RecurseGetSummaryNodes(child_graph.get(), summary);
}
graph->set_summary_nodes(*summary);
@ -941,6 +943,7 @@ void AscendSession::SetFinalGraphOutput(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(final_graph);
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
final_graph->set_executable(false);
MS_EXCEPTION_IF_NULL(value);
MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]";
}
@ -1047,6 +1050,7 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
MS_LOG(INFO) << "The last graph of false branch is " << false_graph_id;
// create fake output
auto fake_output_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(fake_output_graph);
graph_execute_order.push_back(fake_output_graph->graph_id());
graph_order_type.push_back(COMMON_GRAPH);
fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output()));
@ -1079,7 +1083,7 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
auto &graph_order_type = GetGraphOrderType(final_graph_id_);
if (cond_graph_index >= graph_order_type.size()) {
MS_LOG(EXCEPTION) << "cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size();
MS_LOG(EXCEPTION) << "Cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size();
}
graph_order_type[cond_graph_index] = CONDITION_GRAPH;
// update distinction label of false graph,update before merge to sure the distinction
@ -1156,7 +1160,7 @@ void AscendSession::InsertAllAssigns() {
MS_EXCEPTION_IF_NULL(to_graph);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
if (input_idx >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
}
auto backend_parameter = graph_inputs[input_idx];
assigns.emplace_back(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
@ -1180,7 +1184,7 @@ void AscendSession::InsertAllAssigns() {
// insert active to graph
void AscendSession::SetActive(GraphId from, GraphId to) {
if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) {
MS_LOG(WARNING) << " to " << to << " has been exits in map,from " << from << ",exist from "
MS_LOG(WARNING) << "To " << to << " has been exits in map,from " << from << ",exist from "
<< while_condition_graphs_[to];
return;
}
@ -1216,7 +1220,7 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId
MS_EXCEPTION_IF_NULL(to_graph);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
if (input_idx >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
}
auto backend_parameter = graph_inputs[input_idx];
MS_EXCEPTION_IF_NULL(backend_parameter);
@ -1250,7 +1254,7 @@ void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor
}
void AscendSession::UpdateGraphOrder(GraphId to_graph_id) {
MS_LOG(INFO) << "to_graph_id " << to_graph_id;
MS_LOG(INFO) << "To_graph_id " << to_graph_id;
auto &graph_order = GetGraphOrder(final_graph_id_);
auto &graph_type = GetGraphOrderType(final_graph_id_);
for (size_t i = 0; i < graph_order.size(); i++) {
@ -1268,6 +1272,8 @@ void AscendSession::UpdateGraphOrder(GraphId to_graph_id) {
}
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto output_num = AnfAlgo::GetOutputTensorNum(node);
if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
return input_index + output_num;
@ -1282,6 +1288,7 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfN
}
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(value);
if (!value->isa<Tensor>()) {
MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString();
@ -1319,7 +1326,7 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
size_t input_index = 0;
for (size_t i = 0; i < real_args.size(); i++) {
if (input_index >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
MS_LOG(EXCEPTION) << "Input_index " << input_index << " out of range size " << graph_inputs.size();
}
auto &real_arg = real_args[i];
if (utils::isa<AnfNodePtr>(real_arg)) {
@ -1351,7 +1358,7 @@ GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
}
}
MS_EXCEPTION_IF_NULL(front_anf);
MS_LOG(DEBUG) << "front_anf " << front_anf->DebugString() << " is not exist in any graph";
MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
return kInvalidGraphId;
}
@ -1514,7 +1521,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
MS_EXCEPTION_IF_NULL(to_graph);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
if (input_idx >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
}
auto backend_parameter = graph_inputs[input_idx];
// sync data from host to device
@ -1548,6 +1555,7 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
new_kernel_graph->set_return(anf_node);
}
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString();
make_tuple_inputs.push_back(anf_node);
}
@ -1560,7 +1568,7 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
const std::vector<CNodePtr> &list) {
MS_EXCEPTION_IF_NULL(new_kernel_graph);
MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
MS_LOG(INFO) << "Start contruct splited kernel graph:" << new_kernel_graph->graph_id();
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
std::vector<AnfNodePtr> call_node_inputs;
std::vector<AnfNodePtr> new_graph_inputs;
@ -1604,7 +1612,7 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
ConstructSplitedGraphOutput(new_kernel_graph, list);
MS_LOG(INFO) << "end";
MS_LOG(INFO) << "End";
return call_node_inputs;
}
@ -1699,7 +1707,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
}
AscendControlParser::UpdateChildGraphOrder(graph);
UpdateRealInput(graph);
MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
// recurse to split child graph
}

View File

@ -81,7 +81,7 @@ class AscendSession : public SessionBasic {
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void MemoryAlloc(KernelGraph *kernel_graph) const;
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
void RunOpMemoryClear(KernelGraph *kernel_graph) const;
void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
void GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;

View File

@ -33,6 +33,7 @@ namespace mindspore {
namespace session {
ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}

View File

@ -33,6 +33,7 @@ constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(que);
MS_EXCEPTION_IF_NULL(visited_nodes);
if (visited_nodes->find(node) == visited_nodes->end()) {
@ -131,11 +132,10 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
std::vector<AnfNodePtr> active_nodes;
for (const auto &output_edge : it->second) {
auto next_node = output_edge.first;
MS_EXCEPTION_IF_NULL(next_node);
if (node_input_num_.find(next_node) == node_input_num_.end()) {
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
}
MS_EXCEPTION_IF_NULL(next_node);
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
if (node_input_num_[next_node] < output_edge.second) {
@ -147,7 +147,7 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
(void)visited_nodes->insert(next_node);
if (AnfAlgo::IsCommunicationOp(next_node)) {
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
visit_queue->push(next_node);
} else {
active_nodes.emplace_back(next_node);
@ -156,7 +156,8 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
}
for (auto &node : active_nodes) {
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Visit node:" << node->DebugString();
visit_queue->push(node);
}
}
@ -380,12 +381,12 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo
auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
std::vector<AnfNodePtr> convert_inputs;
if (!node_value->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
}
auto value_tuple = node_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
if (value_tuple->size() != output_size) {
MS_LOG(EXCEPTION) << "value tuple size" << value_tuple->size()
MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size()
<< " is not mathced with the value node's output size" << output_size;
}
for (size_t index = 0; index < value_tuple->value().size(); ++index) {
@ -408,7 +409,7 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo
convert_inputs.emplace_back(new_value_node);
}
if (!RemoveValueNodeFromGraph(value_node)) {
MS_LOG(WARNING) << "failed to remove the value_node " << value_node->DebugString();
MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
}
return convert_inputs;
}
@ -429,10 +430,10 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode
MS_EXCEPTION_IF_NULL(front_anf);
MS_EXCEPTION_IF_NULL(backend_anf);
if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
MS_LOG(EXCEPTION) << "anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
}
if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
MS_LOG(EXCEPTION) << "kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
}
front_backend_anf_map_[front_anf] = backend_anf;
backend_front_anf_map_[backend_anf] = front_anf;
@ -442,15 +443,15 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
MS_EXCEPTION_IF_NULL(old_backend_anf);
MS_EXCEPTION_IF_NULL(new_backend_anf);
if (old_backend_anf == new_backend_anf) {
MS_LOG(DEBUG) << "old same with new:" << old_backend_anf->DebugString();
MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString();
return;
}
if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
return;
}
if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
MS_LOG(EXCEPTION) << "anf is not exist in the map ,old " << old_backend_anf->DebugString();
MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString();
}
front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
@ -483,6 +484,8 @@ void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const V
}
void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(input);
MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num;
auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
// add output depend edge of input
@ -571,6 +574,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(que);
MS_EXCEPTION_IF_NULL(visited_nodes);
if (!node->isa<CNode>()) {
return false;
}
@ -596,6 +601,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
}
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
MS_EXCEPTION_IF_NULL(seed_nodes);
node_output_edges_.clear();
node_input_num_.clear();
node_input_edges_.clear();
@ -637,14 +643,14 @@ bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return r
AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
if (!IsInRefOutputMap(out_pair)) {
MS_LOG(EXCEPTION) << "out_pair is not in RefOutputMap";
MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap";
}
return ref_out_in_map_.at(out_pair);
}
void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
if (IsInRefOutputMap(final_pair)) {
MS_LOG(EXCEPTION) << "out_pair is already in RefOutputMap";
MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap";
}
(void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
}
@ -750,7 +756,7 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg);
MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg);
auto iter = std::find_if(
@ -783,7 +789,7 @@ void KernelGraph::UpdateCallRealInput() {
}
void KernelGraph::PrintGraphExecuteOrder() const {
MS_LOG(INFO) << "graph:" << graph_id_ << "execution order";
MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
for (size_t i = 0; i < execution_order_.size(); i++) {
CNodePtr cur_cnode_ptr = execution_order_[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
@ -805,7 +811,7 @@ void KernelGraph::PrintGraphExecuteOrder() const {
}
}
MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
<< event_str << label_str;

View File

@ -75,7 +75,7 @@ class KernelGraph : public FuncGraph {
// add value node tensor relation map
void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node);
// get all value nodes of graph
std::unordered_set<ValueNodePtr> graph_value_nodes() { return graph_value_nodes_; }
const std::unordered_set<ValueNodePtr> graph_value_nodes() const { return graph_value_nodes_; }
// add value node to graph
void AddValueNodeToGraph(const ValueNodePtr &value_node);
// ref output is in map

View File

@ -60,7 +60,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
// if node is a value node, no need sync addr from device to host
if (!AnfAlgo::OutputAddrExist(node, output_index)) {
if (node->isa<ValueNode>()) {
@ -71,13 +71,13 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
if (node->isa<Parameter>()) {
for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
if (input_idx >= input_tensors.size()) {
MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
}
if (graph.inputs()[input_idx] == node) {
return input_tensors[input_idx];
}
}
MS_LOG(EXCEPTION) << "parameter : " << node->DebugString() << "has no output addr";
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
}
}
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
@ -97,7 +97,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
tensor->set_dirty(false);
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
MS_LOG(INFO) << "output sync device to host error!!!";
MS_LOG(INFO) << "Output sync device to host error!!!";
tensor->set_dirty(false);
}
return tensor;
@ -106,10 +106,10 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(anf);
MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]";
MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first);
MS_LOG(INFO) << "create tensor for output after visit:" << item_with_index.first->DebugString();
MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
// special handle for maketuple
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
auto cnode = item_with_index.first->cast<CNodePtr>();
@ -133,7 +133,7 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(anf);
if (!AnfAlgo::IsRealKernel(anf)) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] should be a executable kernel";
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
}
if (anf->isa<ValueNode>()) {
return CreateOneTensor(anf, 0, graph, input_tensors);
@ -194,7 +194,7 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool va
if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
create_parameter((*tuple_abstract)[output_idx]);
}
@ -207,6 +207,7 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool va
}
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Load kInputCtrlTensors";
auto inputs_params = graph->input_ctrl_tensors();
if (inputs_params == nullptr) {
@ -274,10 +275,10 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
}
void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
MS_LOG(INFO) << "graph outputs:";
MS_LOG(INFO) << "Graph outputs:";
const size_t max_deep = 10;
if (recurse_level > max_deep) {
MS_LOG(INFO) << "recurse too deep";
MS_LOG(INFO) << "Recurse too deep";
return;
}
std::string tab_str;
@ -318,7 +319,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
}
MS_EXCEPTION_IF_NULL(graph);
auto m_tensor = GetParamDefaultInputTensor(anf);
@ -523,7 +524,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
}
auto m_tensor = GetParamDefaultInputTensor(anf);
@ -593,6 +594,22 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
return graph;
}
void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// create a new cnode object
auto new_cnode = CreateNewCNode(cnode, graph.get());
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
new_cnode->set_scope(cnode->scope());
graph->FrontBackendlMapAdd(node, new_cnode);
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
graph->set_return(new_cnode);
}
}
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
std::vector<KernelGraphPtr> *all_out_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -630,18 +647,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
}
continue;
} else {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// create a new cnode object
auto new_cnode = CreateNewCNode(cnode, graph.get());
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
new_cnode->set_scope(cnode->scope());
graph->FrontBackendlMapAdd(node, new_cnode);
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
graph->set_return(new_cnode);
}
CreateCNodeKernelGraph(node, graph);
}
}
// if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
@ -670,7 +676,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
continue;
}
MS_LOG(INFO) << "graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
graph_inputs->push_back(backend_parameter);
}
}
@ -686,7 +692,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
}
auto input_nodes = kernel_graph->inputs();
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
<< ", input_ctrl_size:" << input_ctrl_size;
}
auto ms_context = MsContext::GetInstance();
@ -740,7 +746,7 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
continue;
@ -771,7 +777,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
auto cnode = n->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() <= kSummaryGetItem) {
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
}
auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node);
@ -829,7 +835,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
std::vector<AnfNodePtr> output_args;
for (const auto &output : outputs) {
MS_EXCEPTION_IF_NULL(output);
MS_LOG(INFO) << "output:" << output->DebugString();
MS_LOG(INFO) << "Output:" << output->DebugString();
}
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
@ -848,13 +854,13 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr
MS_LOG(INFO) << "Start!";
std::vector<AnfNodePtr> make_tuple_inputs;
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
MS_EXCEPTION_IF_NULL(graph);
if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
auto idx = NewValueNode(SizeToInt(output_index));
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int32Imm>(output_index);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
MS_EXCEPTION_IF_NULL(graph);
auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};

View File

@ -77,6 +77,8 @@ class SessionBasic {
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
void CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
std::vector<KernelGraphPtr> *all_out_graph);