forked from mindspore-Ecosystem/mindspore
use first depend create parameter
This commit is contained in:
parent
ecc168c72f
commit
d1a0ded6c2
|
@ -154,6 +154,9 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
|
|||
const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
|
||||
const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
|
||||
const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
|
||||
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
|
||||
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
||||
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
||||
|
||||
// NN
|
||||
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
|
||||
|
|
|
@ -160,6 +160,9 @@ extern const PrimitivePtr kPrimMul;
|
|||
extern const PrimitivePtr kPrimMinimum;
|
||||
extern const PrimitivePtr kPrimMaximum;
|
||||
extern const PrimitivePtr kPrimSquare;
|
||||
extern const PrimitivePtr kPrimEqual;
|
||||
extern const PrimitivePtr kPrimLess;
|
||||
extern const PrimitivePtr kPrimLessEqual;
|
||||
|
||||
// NN
|
||||
extern const PrimitivePtr kPrimFlatten;
|
||||
|
|
|
@ -506,11 +506,13 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true
|
|||
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
|
||||
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
|
||||
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
|
||||
// condition graph's output must be single output
|
||||
if (condition_graph->outputs().size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Condition_graph output num " << condition_graph_id << " should be 1";
|
||||
auto cond_output_it = condition_output_.find(condition_graph_id);
|
||||
if (cond_output_it == condition_output_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id;
|
||||
}
|
||||
AnfNodePtr cond_output_kernel = condition_graph->outputs()[0];
|
||||
auto cond_output_kernel =
|
||||
AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(cond_output_kernel);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const};
|
||||
CNodePtr switch_node = condition_graph->NewCNode(inputs);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get());
|
||||
|
@ -569,12 +571,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
|
|||
}
|
||||
}
|
||||
|
||||
void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id) {
|
||||
void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id,
|
||||
const AnfNodePtr &output) {
|
||||
if (switches_.find(cond_graph_id) != switches_.end()) {
|
||||
MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before ";
|
||||
return;
|
||||
}
|
||||
switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id);
|
||||
condition_output_[cond_graph_id] = output;
|
||||
MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id;
|
||||
// set the type of condition graph
|
||||
auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id);
|
||||
|
@ -682,12 +686,14 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
|
|||
auto from_graph_id = GetGraphIdByNode(front_anf);
|
||||
auto from_graph = GetGraph(from_graph_id);
|
||||
MS_EXCEPTION_IF_NULL(from_graph);
|
||||
|
||||
auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
|
||||
auto to_graph = GetGraph(to_graph_id);
|
||||
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
|
||||
MS_EXCEPTION_IF_NULL(to_graph);
|
||||
MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node["
|
||||
<< backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get())
|
||||
<< "]";
|
||||
// a node should not assign to itself
|
||||
auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf);
|
||||
if (backend_arg.get() == backend_parameter.get()) {
|
||||
return;
|
||||
}
|
||||
|
@ -703,15 +709,16 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An
|
|||
return;
|
||||
}
|
||||
}
|
||||
InsertMultipleAssignToGraph(from_graph_id, backend_arg, backend_parameter);
|
||||
// if front anf is a parameter, we can assign the value back, because backend_parameter
|
||||
// won't be changed in it's graph unless it's a weight. If backend_parameter is a weight,
|
||||
// we do should assign the value back.
|
||||
auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get());
|
||||
auto to_graph = GetGraph(to_graph_id);
|
||||
MS_EXCEPTION_IF_NULL(to_graph);
|
||||
// if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device
|
||||
// type same to arg
|
||||
if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) {
|
||||
AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get());
|
||||
}
|
||||
InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter);
|
||||
// if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph
|
||||
// unless it's a weigth.If backend_parameter is a weight,we do should assign the value back
|
||||
if (backend_arg->isa<Parameter>() && !to_graph->execution_order().empty()) {
|
||||
InsertMultipleAssignToGraph(to_graph_id, backend_parameter, backend_arg);
|
||||
InsertAssignToGraph(to_graph_id, backend_parameter, backend_arg);
|
||||
}
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
@ -755,7 +762,25 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
|
|||
DumpGraphInputArgs(args);
|
||||
UpdateGraphOrder(g);
|
||||
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
|
||||
auto valid_inputs = to_graph->ValidInputs();
|
||||
size_t real_args_size = 0;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
real_args_size += AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}).size();
|
||||
}
|
||||
if (real_args_size != graph_inputs.size()) {
|
||||
for (size_t j = 0; j < valid_inputs.size(); j++) {
|
||||
if (valid_inputs[j]) {
|
||||
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
|
||||
}
|
||||
}
|
||||
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
|
||||
<< " not equal";
|
||||
}
|
||||
size_t input_index = 0;
|
||||
if (graph_inputs.size() != valid_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
|
||||
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
|
||||
}
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (input_index >= graph_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
|
||||
|
@ -763,6 +788,10 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
|
|||
if (utils::isa<AnfNodePtr>(args[i])) {
|
||||
// arg is a anf node
|
||||
for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) {
|
||||
if (!valid_inputs[input_index]) {
|
||||
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString();
|
||||
continue;
|
||||
}
|
||||
SetChildGraphParameter(real_arg, graph_inputs[input_index]);
|
||||
input_index++;
|
||||
}
|
||||
|
|
|
@ -49,9 +49,8 @@ class AscendSession : public SessionBasic {
|
|||
// set output of final graph
|
||||
void SetFinalGraphOutput(const BaseRef &output) override;
|
||||
// insert switch and set the relative active ops
|
||||
void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g) override;
|
||||
// set args of child graph. the arg maybe come from a output of other child graphs,
|
||||
// or from final graph's parameter
|
||||
void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override;
|
||||
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
|
||||
void SetChildGraphInput(GraphId g, const VectorRef &args) override;
|
||||
// get graph id in child graphs by ME front anf node pointer
|
||||
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
|
||||
|
@ -116,6 +115,7 @@ class AscendSession : public SessionBasic {
|
|||
std::unordered_map<GraphId, GraphId> while_condition_graphs_;
|
||||
// record all conditions
|
||||
std::unordered_map<GraphId, std::pair<GraphId, GraphId>> switches_;
|
||||
std::unordered_map<GraphId, AnfNodePtr> condition_output_;
|
||||
// final_graph_id is used in every root graph has it's own session situation
|
||||
GraphId final_graph_id_;
|
||||
};
|
||||
|
|
|
@ -372,8 +372,7 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
|
|||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
std::vector<AnfNodePtr> prior_nodes = {prior_node};
|
||||
std::vector<AnfNodePtr> depend_nodes = {depend_node};
|
||||
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "],depend node[" << depend_node->DebugString()
|
||||
<< "],depend_mode=[" << AnfAlgo::GetNodeAttr<int>(cnode, "depend_mode") << "]";
|
||||
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString();
|
||||
if (prior_node->isa<Parameter>()) {
|
||||
prior_nodes = GetOutputNodes(prior_node);
|
||||
}
|
||||
|
|
|
@ -86,6 +86,9 @@ class KernelGraph : public FuncGraph {
|
|||
bool executable() const { return executable_; }
|
||||
// set executable of graph
|
||||
void set_executable(bool executable) { executable_ = executable; }
|
||||
// set invalid inputs for control sink
|
||||
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
|
||||
std::vector<bool> ValidInputs() { return valid_inputs_; }
|
||||
|
||||
private:
|
||||
// remove value node form graph
|
||||
|
@ -118,6 +121,8 @@ class KernelGraph : public FuncGraph {
|
|||
std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
|
||||
// graph needn't execute
|
||||
bool executable_;
|
||||
// valid inputs
|
||||
std::vector<bool> valid_inputs_;
|
||||
};
|
||||
} // namespace session
|
||||
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
||||
|
|
|
@ -243,29 +243,38 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
|
|||
return new_value_node;
|
||||
}
|
||||
|
||||
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (!anf->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
|
||||
}
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
auto valid_inputs = graph->MutableValidInputs();
|
||||
MS_EXCEPTION_IF_NULL(valid_inputs);
|
||||
ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
|
||||
graph->FrontBackendlMapAdd(anf, new_parameter);
|
||||
graph_inputs->push_back(new_parameter);
|
||||
valid_inputs->push_back(valid_input);
|
||||
return new_parameter;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
|
||||
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
std::vector<AnfNodePtr> pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
|
||||
auto valid_inputs = graph->MutableValidInputs();
|
||||
MS_EXCEPTION_IF_NULL(valid_inputs);
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
|
||||
auto parameter = graph->NewParameter();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
parameter->set_abstract(abstract);
|
||||
parameters.push_back(graph->NewParameter(parameter));
|
||||
auto new_parameter = graph->NewParameter(parameter);
|
||||
parameters.push_back(new_parameter);
|
||||
valid_inputs->push_back(valid_input);
|
||||
graph_inputs->push_back(new_parameter);
|
||||
};
|
||||
for (const auto &out_node : pre_graph_out) {
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
|
@ -287,18 +296,15 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelG
|
|||
return parameters;
|
||||
}
|
||||
|
||||
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
|
||||
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
if (!anf->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a cnode";
|
||||
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode";
|
||||
}
|
||||
MS_LOG(INFO) << "create a new parameter from cnode[" << anf->DebugString() << "]";
|
||||
auto parameters = CreateParameterFromTuple(anf, graph);
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(*graph_inputs));
|
||||
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
|
||||
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
|
||||
if (parameters.empty()) {
|
||||
MS_LOG(EXCEPTION) << "no parameter exist!!";
|
||||
MS_LOG(EXCEPTION) << "No parameter exist!!";
|
||||
}
|
||||
if (parameters.size() == 1) {
|
||||
return parameters[0];
|
||||
|
@ -307,7 +313,7 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph
|
|||
(void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
|
||||
auto make_tuple = graph->NewCNode(make_tuple_input);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
MS_LOG(INFO) << "new make tuple [" << make_tuple->DebugString() << "] of parameters";
|
||||
MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
|
||||
return make_tuple;
|
||||
}
|
||||
|
||||
|
@ -397,14 +403,20 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
|
|||
|
||||
GraphId SessionBasic::graph_sum_ = 0;
|
||||
|
||||
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
|
||||
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
|
||||
bool *from_other_graph,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(from_other_graph);
|
||||
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
||||
*from_other_graph = false;
|
||||
// get primitive of old node
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
// push attr to inputs[0] of new cnode
|
||||
std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
|
||||
// if has multiple depends,only select first depend as parameter
|
||||
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
||||
auto anf = cnode->inputs()[input_idx];
|
||||
MS_EXCEPTION_IF_NULL(anf);
|
||||
|
@ -412,6 +424,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|||
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
||||
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
||||
continue;
|
||||
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
|
||||
cnode_inputs.push_back((*other_graph_cnode)[anf]);
|
||||
continue;
|
||||
} else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
|
||||
// if input is a value node,
|
||||
auto new_value_node = CreateNewValueNode(anf, graph);
|
||||
|
@ -421,38 +436,60 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
|
|||
continue;
|
||||
} else if (anf->isa<Parameter>()) {
|
||||
// if anf is a parameter
|
||||
cnode_inputs.emplace_back(CreateNewParameterFromParameter(anf, graph));
|
||||
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
|
||||
cnode_inputs.push_back(new_parameter);
|
||||
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
|
||||
graph->FrontBackendlMapAdd(anf, new_parameter);
|
||||
} else {
|
||||
(*other_graph_cnode)[anf] = new_parameter;
|
||||
}
|
||||
continue;
|
||||
} else if (anf->isa<CNode>()) {
|
||||
*from_other_graph = true;
|
||||
// the input node is a cnode from other graph
|
||||
cnode_inputs.emplace_back(CreateNewParameterFromCNode(anf, graph));
|
||||
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
|
||||
cnode_inputs.push_back(parameter_from_cnode);
|
||||
(*other_graph_cnode)[anf] = parameter_from_cnode;
|
||||
continue;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "unexpected input[" << anf->DebugString() << "]";
|
||||
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
|
||||
}
|
||||
return graph->NewCNode(cnode_inputs);
|
||||
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
|
||||
auto new_cnode = graph->NewCNode(cnode_inputs);
|
||||
TraceManager::EndTrace();
|
||||
return new_cnode;
|
||||
}
|
||||
|
||||
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
|
||||
auto graph = std::make_shared<KernelGraph>();
|
||||
graph->set_graph_id(graph_sum_);
|
||||
MS_LOG(INFO) << "Create graph: " << graph_sum_;
|
||||
size_t from_other_graph_depend_num = 0;
|
||||
for (const auto &node : lst) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "start create new cnode,node = " << node->DebugString();
|
||||
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
||||
if (!node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Inst node " << node->DebugString() << " is not CNode";
|
||||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
|
||||
// create a new cnode object
|
||||
auto new_cnode = CreateNewCNode(cnode, graph.get());
|
||||
bool from_other_graph = false;
|
||||
// only first depend from other graph can create
|
||||
bool valid_input = true;
|
||||
if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
||||
valid_input = false;
|
||||
}
|
||||
auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
|
||||
from_other_graph_depend_num++;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
// record map relations between anf from ME and new anf node used in backend
|
||||
graph->FrontBackendlMapAdd(node, new_cnode);
|
||||
TraceManager::EndTrace();
|
||||
}
|
||||
// add a make_tuple at the end of graph as output
|
||||
graph->set_output(ConstructOutput(outputs, graph));
|
||||
|
@ -631,12 +668,15 @@ void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor:
|
|||
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> output_args;
|
||||
auto FindEqu = [graph](const AnfNodePtr &out) -> AnfNodePtr {
|
||||
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
|
||||
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
|
||||
if (backend_anf != nullptr) {
|
||||
return backend_anf;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Can not find the node in the equiv map!";
|
||||
for (const auto &output : outputs) {
|
||||
MS_LOG(INFO) << "output:" << output->DebugString();
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
|
||||
};
|
||||
output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
(void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
|
||||
|
|
|
@ -69,14 +69,15 @@ class SessionBasic {
|
|||
|
||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
|
||||
|
||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
|
||||
// set parameters of final graph
|
||||
virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
|
||||
// set output of final graph
|
||||
virtual void SetFinalGraphOutput(const BaseRef &) {}
|
||||
// insert switch and set the relative active ops
|
||||
virtual void SwitchCompile(GraphId, GraphId, GraphId) {}
|
||||
virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {}
|
||||
// set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter
|
||||
virtual void SetChildGraphInput(GraphId, const VectorRef &) {}
|
||||
// get graph id in child graphs by ME front anf node pointer
|
||||
|
|
|
@ -136,7 +136,7 @@ void MsBackend::SetSwitchGraph() {
|
|||
MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
|
||||
sess_->SwitchCompile(cond_g, true_g, false_g);
|
||||
sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
|
||||
}
|
||||
is_switch_call_ = false;
|
||||
MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
|
||||
|
|
Loading…
Reference in New Issue