forked from mindspore-Ecosystem/mindspore
!22034 mindir:support mindir control flow graph to compile with single pipeline.
Merge pull request !22034 from lanzhineng/mindir_control_flow
This commit is contained in:
commit
bd72a5a851
|
@ -36,6 +36,7 @@
|
|||
#include "pipeline/jit/static_analysis/remove_monad.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
#include "pipeline/jit/static_analysis/program_specialize.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -195,82 +196,6 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
|
|||
return ret;
|
||||
}
|
||||
|
||||
const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto manager = res->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
FuncGraphPtr loaded_graph = nullptr;
|
||||
size_t loaded_graph_num = 0;
|
||||
auto all_graphs = manager->func_graphs();
|
||||
for (auto &graph : all_graphs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (graph->has_attr("is_load")) {
|
||||
loaded_graph = graph;
|
||||
loaded_graph_num += 1;
|
||||
res->set_is_load(true);
|
||||
}
|
||||
}
|
||||
if (loaded_graph_num == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (loaded_graph_num == 1) {
|
||||
return loaded_graph;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num;
|
||||
}
|
||||
|
||||
void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
auto manager = res->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
FuncGraphPtr root_graph = *(manager->roots().begin());
|
||||
auto root_inputs = root_graph->get_inputs();
|
||||
auto loaded_inputs = loaded_graph->get_inputs();
|
||||
MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
|
||||
MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
|
||||
|
||||
size_t root_inputs_num = root_inputs.size();
|
||||
size_t loaded_inputs_num = loaded_inputs.size();
|
||||
if (root_inputs_num != loaded_inputs_num) {
|
||||
MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
|
||||
<< loaded_inputs_num;
|
||||
}
|
||||
for (size_t index = 0; index < root_inputs_num; index++) {
|
||||
auto root_input = root_inputs[index];
|
||||
auto loaded_input = loaded_inputs[index];
|
||||
|
||||
MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
|
||||
MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
|
||||
MS_LOG(DEBUG) << "root_input abstract[" << index
|
||||
<< "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
|
||||
MS_LOG(DEBUG) << "loaded_input abstract [" << index
|
||||
<< "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
|
||||
|
||||
auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
|
||||
auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
|
||||
auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
|
||||
auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
|
||||
|
||||
MS_EXCEPTION_IF_NULL(root_shape);
|
||||
MS_EXCEPTION_IF_NULL(loaded_shape);
|
||||
MS_EXCEPTION_IF_NULL(root_type);
|
||||
MS_EXCEPTION_IF_NULL(loaded_type);
|
||||
|
||||
auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
|
||||
(root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
|
||||
if (!shapeEqu) {
|
||||
MS_EXCEPTION(ValueError) << "The " << index
|
||||
<< " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
|
||||
<< ", input shape of loaded graph: " << loaded_shape->ToString();
|
||||
}
|
||||
if (root_type->type_id() != loaded_type->type_id()) {
|
||||
MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
|
||||
<< " th input type differ from loaded graph. Input type: " << root_type->ToString()
|
||||
<< ", input type of loaded graph: " << loaded_type->ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ParseAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (!res->input()) {
|
||||
|
@ -453,8 +378,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
|
||||
context->ParallelParameterContextInitShape(func_graph);
|
||||
|
||||
// get original loaded graph to check inputs later
|
||||
auto loaded_graph_ptr = GetLoadedGraph(res);
|
||||
// suppose that there is not KeywordArgument for the top graph
|
||||
// get the hyper parameter
|
||||
for (const auto ¶m : func_graph->parameters()) {
|
||||
|
@ -491,10 +414,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|||
}
|
||||
}
|
||||
}
|
||||
// check input after abstract when there is a loaded graph
|
||||
if (loaded_graph_ptr != nullptr) {
|
||||
CheckRootInputShapeAndType(res, loaded_graph_ptr);
|
||||
}
|
||||
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
|
||||
return true;
|
||||
}
|
||||
|
@ -823,6 +742,66 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
|
|||
bool PipelineSplitAction(const ResourcePtr &res) { return PipelineSplitPass(res); }
|
||||
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
|
||||
|
||||
bool SetMindIRGraphAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
res->set_is_load(true);
|
||||
auto cell = py::cast<CellPtr>(res->input());
|
||||
if (cell == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The graph loaded from mindir is null.";
|
||||
}
|
||||
const std::string mindir_graph = "graph_load_from_mindir";
|
||||
auto obj = cell->GetAttr(mindir_graph);
|
||||
if (obj == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The graph loaded from mindir is null. The cell has not attribute: " << mindir_graph;
|
||||
}
|
||||
auto fg = GetValue<FuncGraphPtr>(obj);
|
||||
if (fg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The graph loaded from mindir is null.";
|
||||
}
|
||||
res->set_func_graph(fg);
|
||||
FuncGraphManagerPtr mng = fg->manager();
|
||||
if (mng == nullptr) {
|
||||
auto res_mng = res->manager();
|
||||
MS_EXCEPTION_IF_NULL(res_mng);
|
||||
res_mng->AddFuncGraph(fg);
|
||||
fg->set_manager(res_mng);
|
||||
}
|
||||
abstract::AbstractBasePtrList broaded_args;
|
||||
const auto &args_spec_list = res->args_spec();
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_args),
|
||||
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (arg->GetValueTrack() != kAnyValue) {
|
||||
return arg->Broaden();
|
||||
}
|
||||
return arg;
|
||||
});
|
||||
|
||||
// suppose that there is not KeywordArgument for the top graph
|
||||
// get the hyper parameter
|
||||
for (const auto ¶m : fg->parameters()) {
|
||||
auto param_node = std::static_pointer_cast<Parameter>(param);
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
if (param_node->has_default()) {
|
||||
auto value = param_node->default_param();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
|
||||
auto ref_key = std::make_shared<RefKey>(param_node->name());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
|
||||
broaded_args.push_back(abs_ref);
|
||||
}
|
||||
}
|
||||
auto result = AbstractAnalyze(res, res->func_graph(), broaded_args, true);
|
||||
auto it = abstract::AnalysisResultCacheMgr::GetInstance().begin();
|
||||
auto it_end = abstract::AnalysisResultCacheMgr::GetInstance().end();
|
||||
for (; it != it_end; ++it) {
|
||||
it->first->node()->set_abstract(it->second->abstract());
|
||||
}
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
|
||||
MS_EXCEPTION_IF_NULL(res->manager());
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
|
@ -963,7 +942,17 @@ std::vector<ActionItem> BackendPipeline() {
|
|||
(void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
|
||||
return actions;
|
||||
}
|
||||
|
||||
std::vector<ActionItem> MindIRPipeline() {
|
||||
std::vector<ActionItem> actions;
|
||||
// Set funcGraph loaded from MindIR to resource.
|
||||
(void)actions.emplace_back(std::make_pair("load_mindir", SetMindIRGraphAction));
|
||||
(void)actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
// compile the ANF graph
|
||||
(void)actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
|
||||
// to execute the graph
|
||||
(void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
|
||||
return actions;
|
||||
}
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
std::vector<ActionItem> ServerPipeline() {
|
||||
auto actions = CommonPipeline();
|
||||
|
|
|
@ -49,6 +49,7 @@ bool StartServerAction(const ResourcePtr &res);
|
|||
|
||||
std::vector<ActionItem> GePipeline();
|
||||
std::vector<ActionItem> VmPipeline();
|
||||
std::vector<ActionItem> MindIRPipeline();
|
||||
std::vector<ActionItem> BackendPipeline();
|
||||
std::vector<ActionItem> PServerPipeline();
|
||||
std::vector<ActionItem> ServerPipeline();
|
||||
|
|
|
@ -612,6 +612,11 @@ bool IsPhaseTrain(const std::string &phase_s) {
|
|||
return phase_s.rfind(phase_to_train) != std::string::npos;
|
||||
}
|
||||
|
||||
bool IsPhaseLoadFromMindIR(const std::string &phase_s) {
|
||||
const std::string mindir_graph = "graph_load_from_mindir";
|
||||
return phase_s.rfind(mindir_graph) != std::string::npos;
|
||||
}
|
||||
|
||||
std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
bool is_air = IsPhaseExportAir(phase_s);
|
||||
|
@ -646,6 +651,9 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
|
|||
resource->func_graph() != nullptr) {
|
||||
return BackendPipeline();
|
||||
}
|
||||
if (IsPhaseLoadFromMindIR(phase_s)) {
|
||||
return MindIRPipeline();
|
||||
}
|
||||
return VmPipeline();
|
||||
}
|
||||
return GePipeline();
|
||||
|
|
|
@ -293,6 +293,11 @@ class EvaluatorCacheMgr {
|
|||
// AnalysisCache
|
||||
class AnalysisResultCacheMgr {
|
||||
public:
|
||||
using AnalysisConfigResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
||||
using const_iterator = typename AnalysisConfigResultCache::const_iterator;
|
||||
|
||||
~AnalysisResultCacheMgr() = default;
|
||||
AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
|
||||
AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
|
||||
|
@ -306,17 +311,14 @@ class AnalysisResultCacheMgr {
|
|||
AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf);
|
||||
void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &vale);
|
||||
const_iterator begin() { return cache_.begin(); }
|
||||
const_iterator end() { return cache_.end(); }
|
||||
|
||||
private:
|
||||
using AnalysisConfigAsyncResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, AsyncAbstractPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigAsyncResultCache =
|
||||
MultiThreadCache<AnfNodeConfigPtr, AsyncAbstractPtr, AnalysisConfigAsyncResultMap>;
|
||||
|
||||
using AnalysisConfigResultMap =
|
||||
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||
using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
|
||||
|
||||
AnalysisResultCacheMgr() = default;
|
||||
static AnalysisResultCacheMgr instance_;
|
||||
std::mutex lock_;
|
||||
|
|
|
@ -317,11 +317,10 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
|||
} else if (type->isa<Tuple>()) {
|
||||
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
||||
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
||||
} else if (type->isa<Number>() || type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
|
||||
value_proto->set_denotation(type->type_name());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
|
||||
value_proto->set_denotation(type->type_name());
|
||||
}
|
||||
MS_LOG(DEBUG) << "Value type: " << type->type_name();
|
||||
}
|
||||
|
||||
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
|
@ -366,7 +365,6 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::
|
|||
|
||||
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
||||
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
||||
bool is_only_return = true;
|
||||
for (const AnfNodePtr &node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
|
@ -375,13 +373,9 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP
|
|||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == func_graph->get_return()) {
|
||||
if (is_only_return) {
|
||||
MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!";
|
||||
}
|
||||
BuildOutput(cnode, graph_proto);
|
||||
} else {
|
||||
BuildCNode(cnode, graph_proto);
|
||||
is_only_return = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -393,11 +387,9 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *con
|
|||
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
|
||||
}
|
||||
AnfNodePtr arg = node->input(1);
|
||||
std::string node_name = BuildInputNode(arg, graph_proto);
|
||||
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
|
||||
std::string output_name = GetUniqueNodeName(node);
|
||||
output_proto->set_name(output_name);
|
||||
MS_EXCEPTION_IF_NULL(last_node_);
|
||||
last_node_->set_output(0, output_name);
|
||||
output_proto->set_name(node_name);
|
||||
SetValueInfoProto(arg, output_proto);
|
||||
}
|
||||
|
||||
|
|
|
@ -307,14 +307,16 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
node->set_debug_info(debug_info_ptr);
|
||||
node->set_name(debug_info_name);
|
||||
|
||||
// Set abstract of the parameter
|
||||
if (value_proto.tensor_size() > 0) {
|
||||
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
|
||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
node->set_abstract(tensor_abstract);
|
||||
} else if (value_proto.has_denotation()) {
|
||||
MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
|
||||
}
|
||||
|
||||
anfnode_build_map_[value_proto.name()] = node;
|
||||
return true;
|
||||
}
|
||||
|
@ -776,15 +778,11 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
|
|||
// Operator maybe CNode,FuncGraph or Parameter.
|
||||
|
||||
if (node_type.size() > kOpTypeFlagSize && node_type.substr(0, kOpTypeFlagSize) == kOperatorTypeFlag) {
|
||||
auto it = anfnode_build_map_.find(node_type.substr(kOpTypeFlagSize));
|
||||
if (it != anfnode_build_map_.end()) {
|
||||
auto funcGraph = GetValueNode<FuncGraphPtr>(it->second);
|
||||
if (funcGraph != nullptr) {
|
||||
return NewValueNode(funcGraph);
|
||||
}
|
||||
return it->second;
|
||||
auto anfNode = GetAnfNode(node_type.substr(kOpTypeFlagSize));
|
||||
if (anfNode == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Can't find the ref:" << node_type;
|
||||
return anfNode;
|
||||
}
|
||||
|
||||
// Operator is primitive.
|
||||
|
@ -799,6 +797,7 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
|
|||
MS_EXCEPTION_IF_NULL(prim);
|
||||
prim->set_instance_name(op_name);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Special node_type: " << node_type;
|
||||
prim = std::make_shared<Primitive>(node_type);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
prim->set_instance_name(node_type);
|
||||
|
@ -903,12 +902,12 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
|||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(BuildOperatorNode(node_proto));
|
||||
for (int i = 0; i < node_proto.input_size(); ++i) {
|
||||
const std::string &input_name = node_proto.input(i);
|
||||
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
|
||||
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
|
||||
auto anfNode = GetAnfNode(node_proto.input(i));
|
||||
if (anfNode == nullptr) {
|
||||
MS_LOG(ERROR) << node_name << " input " << i << node_proto.input(i) << "can't find in nodes have parsed";
|
||||
return nullptr;
|
||||
}
|
||||
inputs.push_back(anfnode_build_map_[input_name]);
|
||||
inputs.push_back(anfNode);
|
||||
}
|
||||
|
||||
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
|
@ -929,9 +928,8 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
|||
}
|
||||
|
||||
bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const mind_ir::GraphProto &importProto, const CNodePtr &cnode_ptr) {
|
||||
const mind_ir::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
if (importProto.output_size() < 0 || importProto.output_size() > INT_MAX) {
|
||||
MS_LOG(ERROR) << "importProto.output_size is : " << importProto.output_size();
|
||||
return false;
|
||||
|
@ -944,12 +942,13 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
|
|||
for (int out_size = 0; out_size < importProto.output_size(); ++out_size) {
|
||||
const mind_ir::ValueInfoProto &output_node = importProto.output(out_size);
|
||||
const std::string &out_tuple = output_node.name();
|
||||
if (anfnode_build_map_.find(out_tuple) == anfnode_build_map_.end()) {
|
||||
MS_LOG(ERROR) << "Can't find out_tuple in anfnode_build_map_";
|
||||
auto anfNode = GetAnfNode(out_tuple);
|
||||
if (anfNode == nullptr) {
|
||||
MS_LOG(ERROR) << "Miss return node: " << out_tuple;
|
||||
return false;
|
||||
}
|
||||
inputs.push_back(anfnode_build_map_[out_tuple]);
|
||||
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
|
||||
inputs.push_back(anfNode);
|
||||
elem.push_back(anfNode->abstract());
|
||||
}
|
||||
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(maketuple_ptr);
|
||||
|
@ -961,16 +960,22 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
|
|||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
return_node->set_load_flag(true);
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
|
||||
MS_LOG(DEBUG) << "Construct funcgraph finined, all success.";
|
||||
} else {
|
||||
inputs.clear();
|
||||
inputs.push_back(NewValueNode(prim::kPrimReturn));
|
||||
inputs.push_back(cnode_ptr);
|
||||
auto nodeName = importProto.output(0).name();
|
||||
auto anfNode = GetAnfNode(nodeName);
|
||||
if (anfNode == nullptr) {
|
||||
MS_LOG(ERROR) << "Miss return node: " << nodeName;
|
||||
return false;
|
||||
}
|
||||
inputs.push_back(anfNode);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
return_node->set_load_flag(true);
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
|
||||
MS_LOG(DEBUG) << "Construct funcgraph finined, all success!";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -982,7 +987,7 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
|||
MS_LOG(ERROR) << "importProto.node_size is : " << importProto.node_size();
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
||||
MS_LOG(DEBUG) << "The node size : " << importProto.node_size();
|
||||
CNodePtr cnode_ptr = nullptr;
|
||||
for (int i = 0; i < importProto.node_size(); ++i) {
|
||||
const mind_ir::NodeProto &node_proto = importProto.node(i);
|
||||
|
@ -1001,7 +1006,7 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
|||
}
|
||||
}
|
||||
|
||||
return BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return BuildReturnForFuncGraph(outputFuncGraph, importProto);
|
||||
}
|
||||
|
||||
bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto) {
|
||||
|
@ -1092,4 +1097,17 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
|
|||
anfnode_build_map_.clear();
|
||||
return dstGraph;
|
||||
}
|
||||
|
||||
AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) {
|
||||
auto it = anfnode_build_map_.find(node_name);
|
||||
if (it == anfnode_build_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
FuncGraphPtr func_graph_ptr = GetValueNode<FuncGraphPtr>(it->second);
|
||||
if (func_graph_ptr) {
|
||||
return NewValueNode(func_graph_ptr);
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -52,8 +52,7 @@ class MSANFModelParser {
|
|||
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
|
||||
tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto,
|
||||
const CNodePtr &cnode_ptr);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||
void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
|
||||
|
@ -72,6 +71,7 @@ class MSANFModelParser {
|
|||
bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode(
|
||||
const mind_ir::AttributeProto &attr_proto);
|
||||
AnfNodePtr GetAnfNode(const std::string &node_name);
|
||||
|
||||
std::string producer_name_;
|
||||
std::string model_version_;
|
||||
|
|
|
@ -1458,4 +1458,6 @@ class GraphCell(Cell):
|
|||
return self.graph(*inputs)
|
||||
|
||||
def __call__(self, *inputs):
|
||||
self.phase = "graph_load_from_mindir"
|
||||
self._add_attr("graph_load_from_mindir", self.graph)
|
||||
return self.compile_and_run(*inputs)
|
||||
|
|
Loading…
Reference in New Issue