use first depend create parameter

This commit is contained in:
chenfei 2020-03-27 14:49:16 +08:00
parent ecc168c72f
commit d1a0ded6c2
9 changed files with 128 additions and 48 deletions

3
mindspore/ccsrc/operator/ops.cc Normal file → Executable file
View File

@ -154,6 +154,9 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
// NN // NN
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

3
mindspore/ccsrc/operator/ops.h Normal file → Executable file
View File

@ -160,6 +160,9 @@ extern const PrimitivePtr kPrimMul;
extern const PrimitivePtr kPrimMinimum; extern const PrimitivePtr kPrimMinimum;
extern const PrimitivePtr kPrimMaximum; extern const PrimitivePtr kPrimMaximum;
extern const PrimitivePtr kPrimSquare; extern const PrimitivePtr kPrimSquare;
extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;
// NN // NN
extern const PrimitivePtr kPrimFlatten; extern const PrimitivePtr kPrimFlatten;

59
mindspore/ccsrc/session/ascend_session.cc Normal file → Executable file
View File

@ -506,11 +506,13 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
// condition graph's output must be single output auto cond_output_it = condition_output_.find(condition_graph_id);
if (condition_graph->outputs().size() != 1) { if (cond_output_it == condition_output_.end()) {
MS_LOG(EXCEPTION) << "Condition_graph output num " << condition_graph_id << " should be 1"; MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id;
} }
AnfNodePtr cond_output_kernel = condition_graph->outputs()[0]; auto cond_output_kernel =
AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first;
MS_EXCEPTION_IF_NULL(cond_output_kernel);
std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const};
CNodePtr switch_node = condition_graph->NewCNode(inputs); CNodePtr switch_node = condition_graph->NewCNode(inputs);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get()); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get());
@ -569,12 +571,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
} }
} }
void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id) { void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id,
const AnfNodePtr &output) {
if (switches_.find(cond_graph_id) != switches_.end()) { if (switches_.find(cond_graph_id) != switches_.end()) {
MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before ";
return; return;
} }
switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id); switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id);
condition_output_[cond_graph_id] = output;
MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id;
// set the type of condition graph // set the type of condition graph
auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
@ -682,12 +686,14 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
auto from_graph_id = GetGraphIdByNode(front_anf); auto from_graph_id = GetGraphIdByNode(front_anf);
auto from_graph = GetGraph(from_graph_id); auto from_graph = GetGraph(from_graph_id);
MS_EXCEPTION_IF_NULL(from_graph); MS_EXCEPTION_IF_NULL(from_graph);
auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
auto to_graph = GetGraph(to_graph_id);
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
MS_EXCEPTION_IF_NULL(to_graph);
MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node["
<< backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
<< "]"; << "]";
// a node should not assign to itself // a node should not assign to itself
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
if (backend_arg.get() == backend_parameter.get()) { if (backend_arg.get() == backend_parameter.get()) {
return; return;
} }
@ -703,15 +709,16 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
return; return;
} }
} }
InsertMultipleAssignToGraph(from_graph_id, backend_arg, backend_parameter); // if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device
// if front anf is a parameter, we can assign the value back, because backend_parameter // type same to arg
// won't be changed in it's graph unless it's a weight. If backend_parameter is a weight, if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) {
// we do should assign the value back. AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get());
auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get()); }
auto to_graph = GetGraph(to_graph_id); InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
MS_EXCEPTION_IF_NULL(to_graph); // if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph
// unless it's a weigth.If backend_parameter is a weight,we do should assign the value back
if (backend_arg->isa<Parameter>() && !to_graph->execution_order().empty()) { if (backend_arg->isa<Parameter>() && !to_graph->execution_order().empty()) {
InsertMultipleAssignToGraph(to_graph_id, backend_parameter, backend_arg); InsertAssignToGraph(to_graph_id, backend_parameter, backend_arg);
} }
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
} }
@ -755,7 +762,25 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
DumpGraphInputArgs(args); DumpGraphInputArgs(args);
UpdateGraphOrder(g); UpdateGraphOrder(g);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs(); std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
auto valid_inputs = to_graph->ValidInputs();
size_t real_args_size = 0;
for (size_t i = 0; i < args.size(); i++) {
real_args_size += AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}).size();
}
if (real_args_size != graph_inputs.size()) {
for (size_t j = 0; j < valid_inputs.size(); j++) {
if (valid_inputs[j]) {
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
}
}
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
<< " not equal";
}
size_t input_index = 0; size_t input_index = 0;
if (graph_inputs.size() != valid_inputs.size()) {
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
}
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
if (input_index >= graph_inputs.size()) { if (input_index >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
@ -763,6 +788,10 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
if (utils::isa<AnfNodePtr>(args[i])) { if (utils::isa<AnfNodePtr>(args[i])) {
// arg is a anf node // arg is a anf node
for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) { for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) {
if (!valid_inputs[input_index]) {
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString();
continue;
}
SetChildGraphParameter(real_arg, graph_inputs[input_index]); SetChildGraphParameter(real_arg, graph_inputs[input_index]);
input_index++; input_index++;
} }

6
mindspore/ccsrc/session/ascend_session.h Normal file → Executable file
View File

@ -49,9 +49,8 @@ class AscendSession : public SessionBasic {
// set output of final graph // set output of final graph
void SetFinalGraphOutput(const BaseRef &output) override; void SetFinalGraphOutput(const BaseRef &output) override;
// insert switch and set the relative active ops // insert switch and set the relative active ops
void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g) override; void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override;
// set args of child graph. the arg maybe come from a output of other child graphs, // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
// or from final graph's parameter
void SetChildGraphInput(GraphId g, const VectorRef &args) override; void SetChildGraphInput(GraphId g, const VectorRef &args) override;
// get graph id in child graphs by ME front anf node pointer // get graph id in child graphs by ME front anf node pointer
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
@ -116,6 +115,7 @@ class AscendSession : public SessionBasic {
std::unordered_map<GraphId, GraphId> while_condition_graphs_; std::unordered_map<GraphId, GraphId> while_condition_graphs_;
// record all conditions // record all conditions
std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_; std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
std::unordered_map<GraphId, AnfNodePtr> condition_output_;
// final_graph_id is used in every root graph has it's own session situation // final_graph_id is used in every root graph has it's own session situation
GraphId final_graph_id_; GraphId final_graph_id_;
}; };

3
mindspore/ccsrc/session/kernel_graph.cc Normal file → Executable file
View File

@ -372,8 +372,7 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
MS_EXCEPTION_IF_NULL(depend_node); MS_EXCEPTION_IF_NULL(depend_node);
std::vector<AnfNodePtr> prior_nodes = {prior_node}; std::vector<AnfNodePtr> prior_nodes = {prior_node};
std::vector<AnfNodePtr> depend_nodes = {depend_node}; std::vector<AnfNodePtr> depend_nodes = {depend_node};
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "],depend node[" << depend_node->DebugString() MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString();
<< "],depend_mode=[" << AnfAlgo::GetNodeAttr<int>(cnode, "depend_mode") << "]";
if (prior_node->isa<Parameter>()) { if (prior_node->isa<Parameter>()) {
prior_nodes = GetOutputNodes(prior_node); prior_nodes = GetOutputNodes(prior_node);
} }

5
mindspore/ccsrc/session/kernel_graph.h Normal file → Executable file
View File

@ -86,6 +86,9 @@ class KernelGraph : public FuncGraph {
bool executable() const { return executable_; } bool executable() const { return executable_; }
// set executable of graph // set executable of graph
void set_executable(bool executable) { executable_ = executable; } void set_executable(bool executable) { executable_ = executable; }
// set invalid inputs for control sink
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
std::vector<bool> ValidInputs() { return valid_inputs_; }
private: private:
// remove value node form graph // remove value node form graph
@ -118,6 +121,8 @@ class KernelGraph : public FuncGraph {
std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_; std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
// graph needn't execute // graph needn't execute
bool executable_; bool executable_;
// valid inputs
std::vector<bool> valid_inputs_;
}; };
} // namespace session } // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;

90
mindspore/ccsrc/session/session_basic.cc Normal file → Executable file
View File

@ -243,29 +243,38 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
return new_value_node; return new_value_node;
} }
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) { ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<Parameter>()) { if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
} }
auto graph_inputs = graph->MutableInputs(); auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs); MS_EXCEPTION_IF_NULL(graph_inputs);
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
graph->FrontBackendlMapAdd(anf, new_parameter);
graph_inputs->push_back(new_parameter); graph_inputs->push_back(new_parameter);
valid_inputs->push_back(valid_input);
return new_parameter; return new_parameter;
} }
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters; std::vector<AnfNodePtr> parameters;
std::vector<AnfNodePtr> pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); std::vector<AnfNodePtr> pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
auto parameter = graph->NewParameter(); auto parameter = graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
parameter->set_abstract(abstract); parameter->set_abstract(abstract);
parameters.push_back(graph->NewParameter(parameter)); auto new_parameter = graph->NewParameter(parameter);
parameters.push_back(new_parameter);
valid_inputs->push_back(valid_input);
graph_inputs->push_back(new_parameter);
}; };
for (const auto &out_node : pre_graph_out) { for (const auto &out_node : pre_graph_out) {
MS_EXCEPTION_IF_NULL(out_node); MS_EXCEPTION_IF_NULL(out_node);
@ -287,18 +296,15 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelG
return parameters; return parameters;
} }
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) { AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<CNode>()) { if (!anf->isa<CNode>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a cnode"; MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode";
} }
MS_LOG(INFO) << "create a new parameter from cnode[" << anf->DebugString() << "]"; MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, graph); auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(*graph_inputs));
if (parameters.empty()) { if (parameters.empty()) {
MS_LOG(EXCEPTION) << "no parameter exist!!"; MS_LOG(EXCEPTION) << "No parameter exist!!";
} }
if (parameters.size() == 1) { if (parameters.size() == 1) {
return parameters[0]; return parameters[0];
@ -307,7 +313,7 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
auto make_tuple = graph->NewCNode(make_tuple_input); auto make_tuple = graph->NewCNode(make_tuple_input);
MS_EXCEPTION_IF_NULL(make_tuple); MS_EXCEPTION_IF_NULL(make_tuple);
MS_LOG(INFO) << "new make tuple [" << make_tuple->DebugString() << "] of parameters"; MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
return make_tuple; return make_tuple;
} }
@ -397,14 +403,20 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
GraphId SessionBasic::graph_sum_ = 0; GraphId SessionBasic::graph_sum_ = 0;
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(from_other_graph);
MS_EXCEPTION_IF_NULL(other_graph_cnode);
*from_other_graph = false;
// get primitive of old node // get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode); auto prim = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
// push attr to inputs[0] of new cnode // push attr to inputs[0] of new cnode
std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))}; std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
// if has multiple depends,only select first depend as parameter
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->inputs()[input_idx]; auto anf = cnode->inputs()[input_idx];
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
@ -412,6 +424,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue; continue;
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
cnode_inputs.push_back((*other_graph_cnode)[anf]);
continue;
} else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) { } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
// if input is a value node, // if input is a value node,
auto new_value_node = CreateNewValueNode(anf, graph); auto new_value_node = CreateNewValueNode(anf, graph);
@ -421,38 +436,60 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
continue; continue;
} else if (anf->isa<Parameter>()) { } else if (anf->isa<Parameter>()) {
// if anf is a parameter // if anf is a parameter
cnode_inputs.emplace_back(CreateNewParameterFromParameter(anf, graph)); auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
cnode_inputs.push_back(new_parameter);
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
graph->FrontBackendlMapAdd(anf, new_parameter);
} else {
(*other_graph_cnode)[anf] = new_parameter;
}
continue; continue;
} else if (anf->isa<CNode>()) { } else if (anf->isa<CNode>()) {
*from_other_graph = true;
// the input node is a cnode from other graph // the input node is a cnode from other graph
cnode_inputs.emplace_back(CreateNewParameterFromCNode(anf, graph)); auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
cnode_inputs.push_back(parameter_from_cnode);
(*other_graph_cnode)[anf] = parameter_from_cnode;
continue; continue;
} }
MS_LOG(EXCEPTION) << "unexpected input[" << anf->DebugString() << "]"; MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
} }
return graph->NewCNode(cnode_inputs); TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
TraceManager::EndTrace();
return new_cnode;
} }
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = std::make_shared<KernelGraph>(); auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_); graph->set_graph_id(graph_sum_);
MS_LOG(INFO) << "Create graph: " << graph_sum_;
size_t from_other_graph_depend_num = 0;
for (const auto &node : lst) { for (const auto &node : lst) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "start create new cnode,node = " << node->DebugString(); MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Inst node " << node->DebugString() << " is not CNode"; MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
} }
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
// create a new cnode object // create a new cnode object
auto new_cnode = CreateNewCNode(cnode, graph.get()); bool from_other_graph = false;
// only first depend from other graph can create
bool valid_input = true;
if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
valid_input = false;
}
auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
from_other_graph_depend_num++;
}
MS_EXCEPTION_IF_NULL(new_cnode); MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract()); new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope()); new_cnode->set_scope(cnode->scope());
// record map relations between anf from ME and new anf node used in backend // record map relations between anf from ME and new anf node used in backend
graph->FrontBackendlMapAdd(node, new_cnode); graph->FrontBackendlMapAdd(node, new_cnode);
TraceManager::EndTrace();
} }
// add a make_tuple at the end of graph as output // add a make_tuple at the end of graph as output
graph->set_output(ConstructOutput(outputs, graph)); graph->set_output(ConstructOutput(outputs, graph));
@ -631,12 +668,15 @@ void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor:
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) { CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args; std::vector<AnfNodePtr> output_args;
auto FindEqu = [graph](const AnfNodePtr &out) -> AnfNodePtr { auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
auto backend_anf = graph->GetBackendAnfByFrontAnf(out); auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
if (backend_anf != nullptr) { if (backend_anf != nullptr) {
return backend_anf; return backend_anf;
} }
MS_LOG(EXCEPTION) << "Can not find the node in the equiv map!"; for (const auto &output : outputs) {
MS_LOG(INFO) << "output:" << output->DebugString();
}
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
}; };
output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
(void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args), (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),

5
mindspore/ccsrc/session/session_basic.h Normal file → Executable file
View File

@ -69,14 +69,15 @@ class SessionBasic {
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
// set parameters of final graph // set parameters of final graph
virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; } virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
// set output of final graph // set output of final graph
virtual void SetFinalGraphOutput(const BaseRef &) {} virtual void SetFinalGraphOutput(const BaseRef &) {}
// insert switch and set the relative active ops // insert switch and set the relative active ops
virtual void SwitchCompile(GraphId, GraphId, GraphId) {} virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {}
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
virtual void SetChildGraphInput(GraphId, const VectorRef &) {} virtual void SetChildGraphInput(GraphId, const VectorRef &) {}
// get graph id in child graphs by ME front anf node pointer // get graph id in child graphs by ME front anf node pointer

2
mindspore/ccsrc/vm/backend.cc Normal file → Executable file
View File

@ -136,7 +136,7 @@ void MsBackend::SetSwitchGraph() {
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
} }
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
sess_->SwitchCompile(cond_g, true_g, false_g); sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
} }
is_switch_call_ = false; is_switch_call_ = false;
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_; MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;