forked from mindspore-Ecosystem/mindspore
!2625 Fix code review problems of session
Merge pull request !2625 from Margaret_wangrui/review_code_second
This commit is contained in:
commit
9b326fe8c5
|
@ -163,7 +163,7 @@ void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &inpu
|
|||
UpdateRefNodeOutputMem(graph);
|
||||
}
|
||||
|
||||
void KernelRuntime::RunOpClearMemory(session::KernelGraph *graph) {
|
||||
void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// clear input parameter memory resource
|
||||
for (const auto &input_node : graph->inputs()) {
|
||||
|
|
|
@ -54,7 +54,7 @@ class KernelRuntime {
|
|||
virtual bool Init() = 0;
|
||||
virtual void AssignMemory(session::KernelGraph *graph);
|
||||
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
|
||||
void RunOpClearMemory(session::KernelGraph *graph);
|
||||
void RunOpClearMemory(const session::KernelGraph *graph);
|
||||
virtual bool Run(session::KernelGraph *graph);
|
||||
virtual bool DumpData(session::KernelGraph *graph);
|
||||
virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger);
|
||||
|
|
|
@ -303,7 +303,7 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t input_num = cnode->inputs().size();
|
||||
if (input_num == 0) {
|
||||
MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
|
||||
MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero";
|
||||
}
|
||||
// exclude intputs[0],which is value_node storing attr,inputs left are real input
|
||||
return input_num - 1;
|
||||
|
@ -994,10 +994,10 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node)
|
|||
}
|
||||
|
||||
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
|
||||
MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
|
||||
MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node.";
|
||||
}
|
||||
auto input1 = call_node->input(1);
|
||||
MS_EXCEPTION_IF_NULL(input1);
|
||||
if (input1->isa<ValueNode>()) {
|
||||
|
@ -1009,7 +1009,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
|
|||
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
|
||||
auto switch_node = input1->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_node);
|
||||
auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
|
||||
auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr {
|
||||
auto partial = switch_node->input(input_index);
|
||||
MS_EXCEPTION_IF_NULL(partial);
|
||||
auto partial_cnode = partial->cast<CNodePtr>();
|
||||
|
@ -1031,7 +1031,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
|
|||
bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
|
||||
MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
|
||||
MS_LOG(EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString();
|
||||
}
|
||||
auto input1 = call_node->input(1);
|
||||
if (input1->isa<ValueNode>()) {
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "session/ascend_control_parser.h"
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "session/ascend_control_parser.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "utils/union_find_set.h"
|
||||
|
||||
|
@ -111,7 +111,7 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
|
|||
const std::set<AnfNodePtr> ¶meter_reuse_set,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
if (parameter_reuse_set.empty()) {
|
||||
MS_LOG(EXCEPTION) << "parameter_reuse_set is empty.";
|
||||
MS_LOG(EXCEPTION) << "Parameter_reuse_set is empty.";
|
||||
}
|
||||
if (memo->find(kg.get()) != memo->end()) {
|
||||
return;
|
||||
|
@ -170,6 +170,7 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
|||
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
|
||||
std::map<uint32_t, KernelGraphPtr> graph_id_map;
|
||||
for (auto &g : memo) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id()
|
||||
<< ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString();
|
||||
|
@ -222,6 +223,20 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|||
}
|
||||
}
|
||||
|
||||
NotNull<CNodePtr> AscendControlParser::GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
const CNodePtr &last_label) {
|
||||
CNodePtr start_label;
|
||||
if (last_node != nullptr && last_label != nullptr) {
|
||||
start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
|
||||
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
|
||||
kg->set_start_label(start_label);
|
||||
} else {
|
||||
// no goto node will jump to start label of root graph, so return a fake label
|
||||
start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr));
|
||||
}
|
||||
return NOT_NULL(start_label);
|
||||
}
|
||||
|
||||
NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
const CNodePtr &last_label,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
|
@ -241,28 +256,22 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
|
|||
kg->SetExecOrderByDefault();
|
||||
const std::vector<CNodePtr> &nodes = kg->execution_order();
|
||||
// 4. insert first_label
|
||||
CNodePtr start_label;
|
||||
if (last_node != nullptr && last_label != nullptr) {
|
||||
start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
|
||||
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
|
||||
kg->set_start_label(start_label);
|
||||
} else {
|
||||
// no goto node will jump to start label of root graph, so return a fake label
|
||||
start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr));
|
||||
}
|
||||
CNodePtr start_label = GetStartLabel(kg, last_node, last_label);
|
||||
|
||||
// 5. traverse
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
auto &cnode = nodes[i];
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() < kCNodePrim + 1) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||
}
|
||||
AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
|
||||
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
|
||||
MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
|
||||
MS_LOG(DEBUG) << "Continue node " << cnode->DebugString();
|
||||
continue;
|
||||
}
|
||||
AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (IsValueNode<KernelGraph>(arg)) {
|
||||
RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
|
||||
} else if (!arg->isa<CNode>()) {
|
||||
|
@ -309,6 +318,7 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
|
|||
if (from_graph_call_node != nullptr && last_label != nullptr) {
|
||||
auto label_goto =
|
||||
kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
|
||||
MS_EXCEPTION_IF_NULL(label_goto);
|
||||
MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString();
|
||||
kg->set_end_goto(label_goto);
|
||||
}
|
||||
|
@ -358,6 +368,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|||
}
|
||||
// 1 return label
|
||||
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
|
||||
MS_EXCEPTION_IF_NULL(back_label);
|
||||
MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node "
|
||||
<< cur_node->DebugString();
|
||||
// 2 add depend relationship
|
||||
|
@ -367,6 +378,9 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|||
}
|
||||
// 3 recurse sub graph
|
||||
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
|
||||
if (kCNodeSwitchCond >= origin_switch_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of origin_switch_inputs is not more than " << kCNodeSwitchCond;
|
||||
}
|
||||
std::vector<AnfNodePtr> new_switch_inputs = {
|
||||
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
|
||||
origin_switch_inputs[kCNodeSwitchCond]};
|
||||
|
@ -562,6 +576,8 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
|
|||
auto start_label_set = child_graph->get_start_label();
|
||||
uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex);
|
||||
if (label_index != start_label_set_index) {
|
||||
MS_EXCEPTION_IF_NULL(cur_label);
|
||||
MS_EXCEPTION_IF_NULL(start_label_set);
|
||||
MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString()
|
||||
<< " index " << start_label_set_index << " current child graph order : " << order_index;
|
||||
return false;
|
||||
|
@ -587,7 +603,7 @@ void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
|
|||
}
|
||||
}
|
||||
for (size_t i = 0; i < child_graph_order.size(); i++) {
|
||||
MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
|
||||
MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
|
||||
}
|
||||
kg->set_child_graph_order(child_graph_order);
|
||||
}
|
||||
|
|
|
@ -38,6 +38,8 @@ class AscendControlParser {
|
|||
static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg);
|
||||
|
||||
private:
|
||||
static NotNull<CNodePtr> GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
const CNodePtr &last_label);
|
||||
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
|
||||
const CNodePtr &last_label,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
|
|
|
@ -139,16 +139,16 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar
|
|||
}
|
||||
}
|
||||
if (graph_inputs.size() != valid_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
|
||||
MS_LOG(EXCEPTION) << "Graph_inputs.size(): " << graph_inputs.size()
|
||||
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
|
||||
}
|
||||
if (real_args_size != graph_inputs.size()) {
|
||||
for (size_t j = 0; j < valid_inputs.size(); j++) {
|
||||
if (valid_inputs[j]) {
|
||||
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
|
||||
MS_LOG(INFO) << "Index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
|
||||
}
|
||||
}
|
||||
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
|
||||
MS_LOG(WARNING) << "Real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
|
||||
<< " not equal";
|
||||
}
|
||||
return real_args;
|
||||
|
@ -158,7 +158,7 @@ std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
|
|||
std::vector<CNodePtr> cnodes = {};
|
||||
size_t i = 0;
|
||||
for (const auto &anf : anf_nodes) {
|
||||
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
|
||||
MS_LOG(INFO) << "Apply_list[" << i++ << "] = " << anf->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (anf->isa<CNode>()) {
|
||||
cnodes.push_back(anf->cast<CNodePtr>());
|
||||
|
@ -276,7 +276,7 @@ static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
|
|||
} // namespace
|
||||
|
||||
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_LOG(INFO) << "Start";
|
||||
// construct graph, if successfully, graph_sum_ + 1
|
||||
auto graph = ConstructKernelGraph(lst, outputs);
|
||||
auto graph_id = graph->graph_id();
|
||||
|
@ -285,7 +285,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
|
|||
}
|
||||
|
||||
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_LOG(INFO) << "Start";
|
||||
std::vector<KernelGraphPtr> all_graphs;
|
||||
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
|
||||
BackendOptimization(all_graphs);
|
||||
|
@ -338,7 +338,7 @@ void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph>
|
|||
}
|
||||
|
||||
void AscendSession::BuildGraph(GraphId graph_id) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_LOG(INFO) << "Start";
|
||||
auto graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// resource initialize
|
||||
|
@ -369,6 +369,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
|
|||
MergeGraphExecOrder();
|
||||
} else {
|
||||
auto single_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(single_graph);
|
||||
CompileChildGraph(single_graph);
|
||||
// set the distinction label of single graph
|
||||
single_graph->set_stream_distinction_label(graph_id);
|
||||
|
@ -397,7 +398,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
|
|||
// sync the inital const tensor to device
|
||||
SyncInitialTenosrToDevice();
|
||||
ExportChildGraphs(graph_id);
|
||||
MS_LOG(INFO) << "end";
|
||||
MS_LOG(INFO) << "End";
|
||||
}
|
||||
|
||||
void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
|
||||
|
@ -439,7 +440,7 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
|
|||
|
||||
void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *const outputs) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_LOG(INFO) << "Start";
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
// if none of child graph and no anf output exists
|
||||
|
@ -499,7 +500,7 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
|
|||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
|
||||
if (!ret_ok) {
|
||||
MS_LOG(EXCEPTION) << "run task error!";
|
||||
MS_LOG(EXCEPTION) << "Run task error!";
|
||||
}
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
@ -696,7 +697,7 @@ void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input
|
|||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::RunOpMemoryClear(KernelGraph *kernel_graph) const {
|
||||
void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
|
@ -796,6 +797,7 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
|
|||
GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
|
||||
MS_LOG(INFO) << "Start! Args size " << args.size();
|
||||
auto final_graph = NewKernelGraph();
|
||||
MS_EXCEPTION_IF_NULL(final_graph);
|
||||
final_graph_id_ = final_graph->graph_id();
|
||||
MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success";
|
||||
// init private variables and bind them with final_graph_id
|
||||
|
@ -826,7 +828,7 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
|
|||
final_graph_inputs->push_back(parameter_backend);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(parameter_backend);
|
||||
MS_LOG(INFO) << "parameter backend " << parameter_backend->DebugString() << " belong_graph_id "
|
||||
MS_LOG(INFO) << "Parameter backend " << parameter_backend->DebugString() << " belong_graph_id "
|
||||
<< AnfAlgo::GetGraphId(parameter_backend.get());
|
||||
}
|
||||
MS_LOG(INFO) << "End final_graph_id " << final_graph_id_;
|
||||
|
@ -842,7 +844,7 @@ void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph,
|
|||
if (graph_order_iter == graph_execute_orders_.end()) {
|
||||
SessionBasic::GetSummaryNodes(graph);
|
||||
auto summary_nodes = graph->summary_nodes();
|
||||
(*summary).insert(summary_nodes.begin(), summary_nodes.end());
|
||||
summary->insert(summary_nodes.begin(), summary_nodes.end());
|
||||
return;
|
||||
}
|
||||
// for every child graph, find summary nodes
|
||||
|
@ -854,7 +856,7 @@ void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph,
|
|||
}
|
||||
SessionBasic::GetSummaryNodes(child_graph.get());
|
||||
auto child_graph_summary = child_graph->summary_nodes();
|
||||
(*summary).insert(child_graph_summary.begin(), child_graph_summary.end());
|
||||
summary->insert(child_graph_summary.begin(), child_graph_summary.end());
|
||||
RecurseGetSummaryNodes(child_graph.get(), summary);
|
||||
}
|
||||
graph->set_summary_nodes(*summary);
|
||||
|
@ -941,6 +943,7 @@ void AscendSession::SetFinalGraphOutput(const ValuePtr &value) {
|
|||
MS_EXCEPTION_IF_NULL(final_graph);
|
||||
final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node}));
|
||||
final_graph->set_executable(false);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]";
|
||||
}
|
||||
|
||||
|
@ -1047,6 +1050,7 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
|
|||
MS_LOG(INFO) << "The last graph of false branch is " << false_graph_id;
|
||||
// create fake output
|
||||
auto fake_output_graph = NewKernelGraph();
|
||||
MS_EXCEPTION_IF_NULL(fake_output_graph);
|
||||
graph_execute_order.push_back(fake_output_graph->graph_id());
|
||||
graph_order_type.push_back(COMMON_GRAPH);
|
||||
fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output()));
|
||||
|
@ -1079,7 +1083,7 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
|
|||
auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
|
||||
auto &graph_order_type = GetGraphOrderType(final_graph_id_);
|
||||
if (cond_graph_index >= graph_order_type.size()) {
|
||||
MS_LOG(EXCEPTION) << "cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size();
|
||||
MS_LOG(EXCEPTION) << "Cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size();
|
||||
}
|
||||
graph_order_type[cond_graph_index] = CONDITION_GRAPH;
|
||||
// update distinction label of false graph,update before merge to sure the distinction
|
||||
|
@ -1156,7 +1160,7 @@ void AscendSession::InsertAllAssigns() {
|
|||
MS_EXCEPTION_IF_NULL(to_graph);
|
||||
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
|
||||
if (input_idx >= graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
|
||||
MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
|
||||
}
|
||||
auto backend_parameter = graph_inputs[input_idx];
|
||||
assigns.emplace_back(std::pair<AnfNodePtr, AnfNodePtr>(front_anf, backend_parameter));
|
||||
|
@ -1180,7 +1184,7 @@ void AscendSession::InsertAllAssigns() {
|
|||
// insert active to graph
|
||||
void AscendSession::SetActive(GraphId from, GraphId to) {
|
||||
if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) {
|
||||
MS_LOG(WARNING) << " to " << to << " has been exits in map,from " << from << ",exist from "
|
||||
MS_LOG(WARNING) << "To " << to << " has been exits in map,from " << from << ",exist from "
|
||||
<< while_condition_graphs_[to];
|
||||
return;
|
||||
}
|
||||
|
@ -1216,7 +1220,7 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId
|
|||
MS_EXCEPTION_IF_NULL(to_graph);
|
||||
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
|
||||
if (input_idx >= graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
|
||||
MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
|
||||
}
|
||||
auto backend_parameter = graph_inputs[input_idx];
|
||||
MS_EXCEPTION_IF_NULL(backend_parameter);
|
||||
|
@ -1250,7 +1254,7 @@ void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor
|
|||
}
|
||||
|
||||
void AscendSession::UpdateGraphOrder(GraphId to_graph_id) {
|
||||
MS_LOG(INFO) << "to_graph_id " << to_graph_id;
|
||||
MS_LOG(INFO) << "To_graph_id " << to_graph_id;
|
||||
auto &graph_order = GetGraphOrder(final_graph_id_);
|
||||
auto &graph_type = GetGraphOrderType(final_graph_id_);
|
||||
for (size_t i = 0; i < graph_order.size(); i++) {
|
||||
|
@ -1268,6 +1272,8 @@ void AscendSession::UpdateGraphOrder(GraphId to_graph_id) {
|
|||
}
|
||||
|
||||
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
return input_index + output_num;
|
||||
|
@ -1282,6 +1288,7 @@ size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfN
|
|||
}
|
||||
|
||||
size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!value->isa<Tensor>()) {
|
||||
MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString();
|
||||
|
@ -1319,7 +1326,7 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
|
|||
size_t input_index = 0;
|
||||
for (size_t i = 0; i < real_args.size(); i++) {
|
||||
if (input_index >= graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
|
||||
MS_LOG(EXCEPTION) << "Input_index " << input_index << " out of range size " << graph_inputs.size();
|
||||
}
|
||||
auto &real_arg = real_args[i];
|
||||
if (utils::isa<AnfNodePtr>(real_arg)) {
|
||||
|
@ -1351,7 +1358,7 @@ GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
|
|||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(front_anf);
|
||||
MS_LOG(DEBUG) << "front_anf " << front_anf->DebugString() << " is not exist in any graph";
|
||||
MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
|
||||
return kInvalidGraphId;
|
||||
}
|
||||
|
||||
|
@ -1514,7 +1521,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
|
|||
MS_EXCEPTION_IF_NULL(to_graph);
|
||||
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
|
||||
if (input_idx >= graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "input_index " << input_idx << " out of range size " << graph_inputs.size();
|
||||
MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size();
|
||||
}
|
||||
auto backend_parameter = graph_inputs[input_idx];
|
||||
// sync data from host to device
|
||||
|
@ -1548,6 +1555,7 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
|
|||
new_kernel_graph->set_return(anf_node);
|
||||
}
|
||||
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString();
|
||||
make_tuple_inputs.push_back(anf_node);
|
||||
}
|
||||
|
@ -1560,7 +1568,7 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
|
|||
std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
|
||||
const std::vector<CNodePtr> &list) {
|
||||
MS_EXCEPTION_IF_NULL(new_kernel_graph);
|
||||
MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
|
||||
MS_LOG(INFO) << "Start contruct splited kernel graph:" << new_kernel_graph->graph_id();
|
||||
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
|
||||
std::vector<AnfNodePtr> call_node_inputs;
|
||||
std::vector<AnfNodePtr> new_graph_inputs;
|
||||
|
@ -1604,7 +1612,7 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
|||
|
||||
MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
|
||||
ConstructSplitedGraphOutput(new_kernel_graph, list);
|
||||
MS_LOG(INFO) << "end";
|
||||
MS_LOG(INFO) << "End";
|
||||
return call_node_inputs;
|
||||
}
|
||||
|
||||
|
@ -1699,7 +1707,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
|
|||
}
|
||||
AscendControlParser::UpdateChildGraphOrder(graph);
|
||||
UpdateRealInput(graph);
|
||||
MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
|
||||
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
|
||||
// recurse to split child graph
|
||||
}
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ class AscendSession : public SessionBasic {
|
|||
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void MemoryAlloc(KernelGraph *kernel_graph) const;
|
||||
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
|
||||
void RunOpMemoryClear(KernelGraph *kernel_graph) const;
|
||||
void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
|
||||
void GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void ExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
|
|
@ -33,6 +33,7 @@ namespace mindspore {
|
|||
namespace session {
|
||||
ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
|
|||
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
|
||||
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(que);
|
||||
MS_EXCEPTION_IF_NULL(visited_nodes);
|
||||
if (visited_nodes->find(node) == visited_nodes->end()) {
|
||||
|
@ -131,11 +132,10 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
|
|||
std::vector<AnfNodePtr> active_nodes;
|
||||
for (const auto &output_edge : it->second) {
|
||||
auto next_node = output_edge.first;
|
||||
MS_EXCEPTION_IF_NULL(next_node);
|
||||
if (node_input_num_.find(next_node) == node_input_num_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(next_node);
|
||||
MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(next_node);
|
||||
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
|
||||
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
|
||||
if (node_input_num_[next_node] < output_edge.second) {
|
||||
|
@ -147,7 +147,7 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
|
|||
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
|
||||
(void)visited_nodes->insert(next_node);
|
||||
if (AnfAlgo::IsCommunicationOp(next_node)) {
|
||||
MS_LOG(DEBUG) << "visit node:" << next_node->DebugString();
|
||||
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
|
||||
visit_queue->push(next_node);
|
||||
} else {
|
||||
active_nodes.emplace_back(next_node);
|
||||
|
@ -156,7 +156,8 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
|
|||
}
|
||||
|
||||
for (auto &node : active_nodes) {
|
||||
MS_LOG(DEBUG) << "visit node:" << node->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Visit node:" << node->DebugString();
|
||||
visit_queue->push(node);
|
||||
}
|
||||
}
|
||||
|
@ -380,12 +381,12 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo
|
|||
auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
|
||||
std::vector<AnfNodePtr> convert_inputs;
|
||||
if (!node_value->isa<ValueTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
|
||||
MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
|
||||
}
|
||||
auto value_tuple = node_value->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
if (value_tuple->size() != output_size) {
|
||||
MS_LOG(EXCEPTION) << "value tuple size" << value_tuple->size()
|
||||
MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size()
|
||||
<< " is not mathced with the value node's output size" << output_size;
|
||||
}
|
||||
for (size_t index = 0; index < value_tuple->value().size(); ++index) {
|
||||
|
@ -408,7 +409,7 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNo
|
|||
convert_inputs.emplace_back(new_value_node);
|
||||
}
|
||||
if (!RemoveValueNodeFromGraph(value_node)) {
|
||||
MS_LOG(WARNING) << "failed to remove the value_node " << value_node->DebugString();
|
||||
MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
|
||||
}
|
||||
return convert_inputs;
|
||||
}
|
||||
|
@ -429,10 +430,10 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode
|
|||
MS_EXCEPTION_IF_NULL(front_anf);
|
||||
MS_EXCEPTION_IF_NULL(backend_anf);
|
||||
if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
|
||||
MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_";
|
||||
}
|
||||
if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
|
||||
MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_";
|
||||
}
|
||||
front_backend_anf_map_[front_anf] = backend_anf;
|
||||
backend_front_anf_map_[backend_anf] = front_anf;
|
||||
|
@ -442,15 +443,15 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons
|
|||
MS_EXCEPTION_IF_NULL(old_backend_anf);
|
||||
MS_EXCEPTION_IF_NULL(new_backend_anf);
|
||||
if (old_backend_anf == new_backend_anf) {
|
||||
MS_LOG(DEBUG) << "old same with new:" << old_backend_anf->DebugString();
|
||||
MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString();
|
||||
return;
|
||||
}
|
||||
if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
|
||||
MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
|
||||
MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map";
|
||||
return;
|
||||
}
|
||||
if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "anf is not exist in the map ,old " << old_backend_anf->DebugString();
|
||||
MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString();
|
||||
}
|
||||
front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
|
||||
backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
|
||||
|
@ -483,6 +484,8 @@ void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const V
|
|||
}
|
||||
|
||||
void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num;
|
||||
auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
|
||||
// add output depend edge of input
|
||||
|
@ -625,6 +628,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
|
|||
bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(que);
|
||||
MS_EXCEPTION_IF_NULL(visited_nodes);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
@ -650,6 +655,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<Anf
|
|||
}
|
||||
|
||||
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(seed_nodes);
|
||||
node_output_edges_.clear();
|
||||
node_input_num_.clear();
|
||||
node_input_edges_.clear();
|
||||
|
@ -691,14 +697,14 @@ bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return r
|
|||
|
||||
AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
|
||||
if (!IsInRefOutputMap(out_pair)) {
|
||||
MS_LOG(EXCEPTION) << "out_pair is not in RefOutputMap";
|
||||
MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap";
|
||||
}
|
||||
return ref_out_in_map_.at(out_pair);
|
||||
}
|
||||
|
||||
void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
|
||||
if (IsInRefOutputMap(final_pair)) {
|
||||
MS_LOG(EXCEPTION) << "out_pair is already in RefOutputMap";
|
||||
MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap";
|
||||
}
|
||||
(void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
|
||||
}
|
||||
|
@ -804,7 +810,7 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
|
|||
void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
|
||||
MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto iter = std::find_if(
|
||||
|
@ -837,7 +843,7 @@ void KernelGraph::UpdateCallRealInput() {
|
|||
}
|
||||
|
||||
void KernelGraph::PrintGraphExecuteOrder() const {
|
||||
MS_LOG(INFO) << "graph:" << graph_id_ << "execution order";
|
||||
MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order";
|
||||
for (size_t i = 0; i < execution_order_.size(); i++) {
|
||||
CNodePtr cur_cnode_ptr = execution_order_[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
|
@ -859,7 +865,7 @@ void KernelGraph::PrintGraphExecuteOrder() const {
|
|||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
|
||||
MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
|
||||
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
|
||||
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
|
||||
<< event_str << label_str;
|
||||
|
|
|
@ -75,7 +75,7 @@ class KernelGraph : public FuncGraph {
|
|||
// add value node tensor relation map
|
||||
void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node);
|
||||
// get all value nodes of graph
|
||||
std::unordered_set<ValueNodePtr> graph_value_nodes() { return graph_value_nodes_; }
|
||||
const std::unordered_set<ValueNodePtr> graph_value_nodes() const { return graph_value_nodes_; }
|
||||
// add value node to graph
|
||||
void AddValueNodeToGraph(const ValueNodePtr &value_node);
|
||||
// ref output is in map
|
||||
|
|
|
@ -60,7 +60,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
|
|||
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(INFO) << "create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
|
||||
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
|
||||
// if node is a value node, no need sync addr from device to host
|
||||
if (!AnfAlgo::OutputAddrExist(node, output_index)) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
|
@ -71,13 +71,13 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
if (node->isa<Parameter>()) {
|
||||
for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
|
||||
if (input_idx >= input_tensors.size()) {
|
||||
MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
|
||||
MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
|
||||
}
|
||||
if (graph.inputs()[input_idx] == node) {
|
||||
return input_tensors[input_idx];
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "parameter : " << node->DebugString() << "has no output addr";
|
||||
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
|
||||
}
|
||||
}
|
||||
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
|
||||
|
@ -105,7 +105,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
tensor->set_dirty(false);
|
||||
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
|
||||
MS_LOG(INFO) << "output sync device to host error!!!";
|
||||
MS_LOG(INFO) << "Output sync device to host error!!!";
|
||||
tensor->set_dirty(false);
|
||||
}
|
||||
return tensor;
|
||||
|
@ -114,10 +114,10 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]";
|
||||
MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
|
||||
MS_EXCEPTION_IF_NULL(item_with_index.first);
|
||||
MS_LOG(INFO) << "create tensor for output after visit:" << item_with_index.first->DebugString();
|
||||
MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
|
||||
// special handle for maketuple
|
||||
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
|
||||
auto cnode = item_with_index.first->cast<CNodePtr>();
|
||||
|
@ -141,7 +141,7 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
|
|||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (!AnfAlgo::IsRealKernel(anf)) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] should be a executable kernel";
|
||||
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
|
||||
}
|
||||
if (anf->isa<ValueNode>()) {
|
||||
return CreateOneTensor(anf, 0, graph, input_tensors);
|
||||
|
@ -202,7 +202,7 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool va
|
|||
if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
|
||||
MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
|
||||
for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
|
||||
create_parameter((*tuple_abstract)[output_idx]);
|
||||
}
|
||||
|
@ -215,6 +215,7 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool va
|
|||
}
|
||||
|
||||
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Load kInputCtrlTensors";
|
||||
auto inputs_params = graph->input_ctrl_tensors();
|
||||
if (inputs_params == nullptr) {
|
||||
|
@ -282,10 +283,10 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
|
|||
}
|
||||
|
||||
void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
|
||||
MS_LOG(INFO) << "graph outputs:";
|
||||
MS_LOG(INFO) << "Graph outputs:";
|
||||
const size_t max_deep = 10;
|
||||
if (recurse_level > max_deep) {
|
||||
MS_LOG(INFO) << "recurse too deep";
|
||||
MS_LOG(INFO) << "Recurse too deep";
|
||||
return;
|
||||
}
|
||||
std::string tab_str;
|
||||
|
@ -326,7 +327,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
|
|||
KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
||||
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto m_tensor = GetParamDefaultInputTensor(anf);
|
||||
|
@ -531,7 +532,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
|
|||
MS_EXCEPTION_IF_NULL(anf);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
||||
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
|
||||
}
|
||||
|
||||
auto m_tensor = GetParamDefaultInputTensor(anf);
|
||||
|
@ -601,6 +602,22 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|||
return graph;
|
||||
}
|
||||
|
||||
void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// create a new cnode object
|
||||
auto new_cnode = CreateNewCNode(cnode, graph.get());
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
graph->FrontBackendlMapAdd(node, new_cnode);
|
||||
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
|
||||
graph->set_return(new_cnode);
|
||||
}
|
||||
}
|
||||
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||
std::vector<KernelGraphPtr> *all_out_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -638,18 +655,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
}
|
||||
continue;
|
||||
} else {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// create a new cnode object
|
||||
auto new_cnode = CreateNewCNode(cnode, graph.get());
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
graph->FrontBackendlMapAdd(node, new_cnode);
|
||||
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
|
||||
graph->set_return(new_cnode);
|
||||
}
|
||||
CreateCNodeKernelGraph(node, graph);
|
||||
}
|
||||
}
|
||||
// if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
|
||||
|
@ -678,7 +684,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> ¶
|
|||
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
|
||||
continue;
|
||||
}
|
||||
MS_LOG(INFO) << "graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
|
||||
MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
|
||||
graph_inputs->push_back(backend_parameter);
|
||||
}
|
||||
}
|
||||
|
@ -694,7 +700,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
}
|
||||
auto input_nodes = kernel_graph->inputs();
|
||||
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
|
||||
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
||||
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
|
||||
<< ", input_ctrl_size:" << input_ctrl_size;
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
|
@ -748,7 +754,7 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
auto anf_outputs = kernel_graph->outputs();
|
||||
for (auto &item : anf_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
|
||||
MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
|
||||
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
|
||||
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
|
||||
continue;
|
||||
|
@ -779,7 +785,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
|
|||
auto cnode = n->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() <= kSummaryGetItem) {
|
||||
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
|
||||
MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
|
||||
}
|
||||
auto node = cnode->input(kSummaryGetItem);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -837,7 +843,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
|
|||
std::vector<AnfNodePtr> output_args;
|
||||
for (const auto &output : outputs) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
MS_LOG(INFO) << "output:" << output->DebugString();
|
||||
MS_LOG(INFO) << "Output:" << output->DebugString();
|
||||
}
|
||||
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
|
||||
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
|
||||
|
@ -856,13 +862,13 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr
|
|||
MS_LOG(INFO) << "Start!";
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
|
||||
for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
|
||||
auto idx = NewValueNode(SizeToInt(output_index));
|
||||
MS_EXCEPTION_IF_NULL(idx);
|
||||
auto imm = std::make_shared<Int32Imm>(output_index);
|
||||
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
|
||||
std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
|
||||
std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
|
||||
|
|
|
@ -77,6 +77,8 @@ class SessionBasic {
|
|||
|
||||
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
||||
|
||||
void CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph);
|
||||
|
||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
|
||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||
std::vector<KernelGraphPtr> *all_out_graph);
|
||||
|
|
Loading…
Reference in New Issue