use first depend create parameter

This commit is contained in:
chenfei 2020-03-27 14:49:16 +08:00
parent ecc168c72f
commit d1a0ded6c2
9 changed files with 128 additions and 48 deletions

3
mindspore/ccsrc/operator/ops.cc Normal file → Executable file
View File

@ -154,6 +154,9 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
// NN
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

3
mindspore/ccsrc/operator/ops.h Normal file → Executable file
View File

@ -160,6 +160,9 @@ extern const PrimitivePtr kPrimMul;
extern const PrimitivePtr kPrimMinimum;
extern const PrimitivePtr kPrimMaximum;
extern const PrimitivePtr kPrimSquare;
extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;
// NN
extern const PrimitivePtr kPrimFlatten;

59
mindspore/ccsrc/session/ascend_session.cc Normal file → Executable file
View File

@ -506,11 +506,13 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
// condition graph's output must be single output
if (condition_graph->outputs().size() != 1) {
MS_LOG(EXCEPTION) << "Condition_graph output num " << condition_graph_id << " should be 1";
auto cond_output_it = condition_output_.find(condition_graph_id);
if (cond_output_it == condition_output_.end()) {
MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id;
}
AnfNodePtr cond_output_kernel = condition_graph->outputs()[0];
auto cond_output_kernel =
AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first;
MS_EXCEPTION_IF_NULL(cond_output_kernel);
std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const};
CNodePtr switch_node = condition_graph->NewCNode(inputs);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get());
@ -569,12 +571,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
}
}
void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id) {
void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id,
const AnfNodePtr &output) {
if (switches_.find(cond_graph_id) != switches_.end()) {
MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before ";
return;
}
switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id);
condition_output_[cond_graph_id] = output;
MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id;
// set the type of condition graph
auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
@ -682,12 +686,14 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
auto from_graph_id = GetGraphIdByNode(front_anf);
auto from_graph = GetGraph(from_graph_id);
MS_EXCEPTION_IF_NULL(from_graph);
auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
auto to_graph = GetGraph(to_graph_id);
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
MS_EXCEPTION_IF_NULL(to_graph);
MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node["
<< backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
<< "]";
// a node should not assign to itself
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
if (backend_arg.get() == backend_parameter.get()) {
return;
}
@ -703,15 +709,16 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
return;
}
}
InsertMultipleAssignToGraph(from_graph_id, backend_arg, backend_parameter);
// if front anf is a parameter, we can assign the value back, because backend_parameter
// won't be changed in it's graph unless it's a weight. If backend_parameter is a weight,
// we do should assign the value back.
auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
auto to_graph = GetGraph(to_graph_id);
MS_EXCEPTION_IF_NULL(to_graph);
// if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device
// type same to arg
if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) {
AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get());
}
InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
// if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph
// unless it's a weigth.If backend_parameter is a weight,we do should assign the value back
if (backend_arg->isa<Parameter>() && !to_graph->execution_order().empty()) {
InsertMultipleAssignToGraph(to_graph_id, backend_parameter, backend_arg);
InsertAssignToGraph(to_graph_id, backend_parameter, backend_arg);
}
MS_LOG(INFO) << "Finish!";
}
@ -755,7 +762,25 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
DumpGraphInputArgs(args);
UpdateGraphOrder(g);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
auto valid_inputs = to_graph->ValidInputs();
size_t real_args_size = 0;
for (size_t i = 0; i < args.size(); i++) {
real_args_size += AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}).size();
}
if (real_args_size != graph_inputs.size()) {
for (size_t j = 0; j < valid_inputs.size(); j++) {
if (valid_inputs[j]) {
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
}
}
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
<< " not equal";
}
size_t input_index = 0;
if (graph_inputs.size() != valid_inputs.size()) {
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
}
for (size_t i = 0; i < args.size(); i++) {
if (input_index >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
@ -763,6 +788,10 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
if (utils::isa<AnfNodePtr>(args[i])) {
// arg is a anf node
for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) {
if (!valid_inputs[input_index]) {
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString();
continue;
}
SetChildGraphParameter(real_arg, graph_inputs[input_index]);
input_index++;
}

6
mindspore/ccsrc/session/ascend_session.h Normal file → Executable file
View File

@ -49,9 +49,8 @@ class AscendSession : public SessionBasic {
// set output of final graph
void SetFinalGraphOutput(const BaseRef &output) override;
// insert switch and set the relative active ops
void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g) override;
// set args of child graph. the arg maybe come from a output of other child graphs,
// or from final graph's parameter
void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override;
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
void SetChildGraphInput(GraphId g, const VectorRef &args) override;
// get graph id in child graphs by ME front anf node pointer
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
@ -116,6 +115,7 @@ class AscendSession : public SessionBasic {
std::unordered_map<GraphId, GraphId> while_condition_graphs_;
// record all conditions
std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
std::unordered_map<GraphId, AnfNodePtr> condition_output_;
// final_graph_id is used in every root graph has it's own session situation
GraphId final_graph_id_;
};

3
mindspore/ccsrc/session/kernel_graph.cc Normal file → Executable file
View File

@ -372,8 +372,7 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
MS_EXCEPTION_IF_NULL(depend_node);
std::vector<AnfNodePtr> prior_nodes = {prior_node};
std::vector<AnfNodePtr> depend_nodes = {depend_node};
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "],depend node[" << depend_node->DebugString()
<< "],depend_mode=[" << AnfAlgo::GetNodeAttr<int>(cnode, "depend_mode") << "]";
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString();
if (prior_node->isa<Parameter>()) {
prior_nodes = GetOutputNodes(prior_node);
}

5
mindspore/ccsrc/session/kernel_graph.h Normal file → Executable file
View File

@ -86,6 +86,9 @@ class KernelGraph : public FuncGraph {
bool executable() const { return executable_; }
// set executable of graph
void set_executable(bool executable) { executable_ = executable; }
// set invalid inputs for control sink
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
std::vector<bool> ValidInputs() { return valid_inputs_; }
private:
// remove value node form graph
@ -118,6 +121,8 @@ class KernelGraph : public FuncGraph {
std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
// graph needn't execute
bool executable_;
// valid inputs
std::vector<bool> valid_inputs_;
};
} // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;

90
mindspore/ccsrc/session/session_basic.cc Normal file → Executable file
View File

@ -243,29 +243,38 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
return new_value_node;
}
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
graph->FrontBackendlMapAdd(anf, new_parameter);
graph_inputs->push_back(new_parameter);
valid_inputs->push_back(valid_input);
return new_parameter;
}
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters;
std::vector<AnfNodePtr> pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
auto parameter = graph->NewParameter();
MS_EXCEPTION_IF_NULL(parameter);
parameter->set_abstract(abstract);
parameters.push_back(graph->NewParameter(parameter));
auto new_parameter = graph->NewParameter(parameter);
parameters.push_back(new_parameter);
valid_inputs->push_back(valid_input);
graph_inputs->push_back(new_parameter);
};
for (const auto &out_node : pre_graph_out) {
MS_EXCEPTION_IF_NULL(out_node);
@ -287,18 +296,15 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelG
return parameters;
}
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<CNode>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a cnode";
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode";
}
MS_LOG(INFO) << "create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, graph);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(*graph_inputs));
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
if (parameters.empty()) {
MS_LOG(EXCEPTION) << "no parameter exist!!";
MS_LOG(EXCEPTION) << "No parameter exist!!";
}
if (parameters.size() == 1) {
return parameters[0];
@ -307,7 +313,7 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
auto make_tuple = graph->NewCNode(make_tuple_input);
MS_EXCEPTION_IF_NULL(make_tuple);
MS_LOG(INFO) << "new make tuple [" << make_tuple->DebugString() << "] of parameters";
MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
return make_tuple;
}
@ -397,14 +403,20 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
GraphId SessionBasic::graph_sum_ = 0;
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(from_other_graph);
MS_EXCEPTION_IF_NULL(other_graph_cnode);
*from_other_graph = false;
// get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
// push attr to inputs[0] of new cnode
std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
// if has multiple depends,only select first depend as parameter
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->inputs()[input_idx];
MS_EXCEPTION_IF_NULL(anf);
@ -412,6 +424,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
cnode_inputs.push_back((*other_graph_cnode)[anf]);
continue;
} else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
// if input is a value node,
auto new_value_node = CreateNewValueNode(anf, graph);
@ -421,38 +436,60 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
continue;
} else if (anf->isa<Parameter>()) {
// if anf is a parameter
cnode_inputs.emplace_back(CreateNewParameterFromParameter(anf, graph));
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
cnode_inputs.push_back(new_parameter);
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
graph->FrontBackendlMapAdd(anf, new_parameter);
} else {
(*other_graph_cnode)[anf] = new_parameter;
}
continue;
} else if (anf->isa<CNode>()) {
*from_other_graph = true;
// the input node is a cnode from other graph
cnode_inputs.emplace_back(CreateNewParameterFromCNode(anf, graph));
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
cnode_inputs.push_back(parameter_from_cnode);
(*other_graph_cnode)[anf] = parameter_from_cnode;
continue;
}
MS_LOG(EXCEPTION) << "unexpected input[" << anf->DebugString() << "]";
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
}
return graph->NewCNode(cnode_inputs);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
TraceManager::EndTrace();
return new_cnode;
}
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = std::make_shared<KernelGraph>();
graph->set_graph_id(graph_sum_);
MS_LOG(INFO) << "Create graph: " << graph_sum_;
size_t from_other_graph_depend_num = 0;
for (const auto &node : lst) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "start create new cnode,node = " << node->DebugString();
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Inst node " << node->DebugString() << " is not CNode";
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
// create a new cnode object
auto new_cnode = CreateNewCNode(cnode, graph.get());
bool from_other_graph = false;
// only first depend from other graph can create
bool valid_input = true;
if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
valid_input = false;
}
auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
from_other_graph_depend_num++;
}
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
// record map relations between anf from ME and new anf node used in backend
graph->FrontBackendlMapAdd(node, new_cnode);
TraceManager::EndTrace();
}
// add a make_tuple at the end of graph as output
graph->set_output(ConstructOutput(outputs, graph));
@ -631,12 +668,15 @@ void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor:
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args;
auto FindEqu = [graph](const AnfNodePtr &out) -> AnfNodePtr {
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
if (backend_anf != nullptr) {
return backend_anf;
}
MS_LOG(EXCEPTION) << "Can not find the node in the equiv map!";
for (const auto &output : outputs) {
MS_LOG(INFO) << "output:" << output->DebugString();
}
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
};
output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
(void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),

5
mindspore/ccsrc/session/session_basic.h Normal file → Executable file
View File

@ -69,14 +69,15 @@ class SessionBasic {
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
// set parameters of final graph
virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
// set output of final graph
virtual void SetFinalGraphOutput(const BaseRef &) {}
// insert switch and set the relative active ops
virtual void SwitchCompile(GraphId, GraphId, GraphId) {}
virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {}
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
virtual void SetChildGraphInput(GraphId, const VectorRef &) {}
// get graph id in child graphs by ME front anf node pointer

2
mindspore/ccsrc/vm/backend.cc Normal file → Executable file
View File

@ -136,7 +136,7 @@ void MsBackend::SetSwitchGraph() {
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
}
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
sess_->SwitchCompile(cond_g, true_g, false_g);
sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
}
is_switch_call_ = false;
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;