!44855 modify code in lite to adapt to new ge_adapt

Merge pull request !44855 from xiaoyao/ge_mindrt
This commit is contained in:
i-robot 2022-11-01 08:34:43 +00:00 committed by Gitee
commit 85800a67d7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 26 additions and 57 deletions

View File

@@ -939,7 +939,7 @@ std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGrap
// Create parameter
for (const auto &node : func_graph->parameters()) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto new_parameter = CreateNewParameter(node, graph.get());
@@ -951,7 +951,7 @@ std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGrap
if (node->isa<Parameter>()) {
continue;
}
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
// Create value node
if (node->isa<ValueNode>()) {
// Create common value node

View File

@@ -86,7 +86,7 @@ class DfGraphConvertor {
} else {
ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE);
}
is_kernel_graph_ = anf_graph_->type_name() == kKernelGraphClassName;
is_kernel_graph_ = anf_graph_->type_name() == kKernelGraphTypeName;
df_graph_ = std::make_shared<DfGraph>(anf_graph_->ToString());
std::string graph_type = is_kernel_graph_ ? "kernel_graph" : "func_graph";
std::string graph_name = anf_graph_->ToString();

View File

@@ -105,7 +105,7 @@ bool IsWhileNode(const AnfNodePtr &node) {
}
auto graph = node->func_graph();
MS_EXCEPTION_IF_NULL(graph);
bool in_kg = graph->type_name() == kKernelGraphClassName;
bool in_kg = graph->type_name() == kKernelGraphTypeName;
auto cnode = node->cast<CNodePtr>();
ValueNodePtr graph_node = nullptr;
if (in_kg && IsPrimitiveCNode(node, prim::kPrimCall) && cnode->input(1)->isa<ValueNode>()) {
@@ -176,7 +176,7 @@ bool IsIfNode(const AnfNodePtr &node) {
}
auto graph = node->func_graph();
MS_EXCEPTION_IF_NULL(graph);
bool in_kg = graph->type_name() == kKernelGraphClassName;
bool in_kg = graph->type_name() == kKernelGraphTypeName;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
CNodePtr switch_node = nullptr;
@@ -222,7 +222,7 @@ bool IsCaseNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(cnode);
auto graph = node->func_graph();
MS_EXCEPTION_IF_NULL(graph);
bool in_kg = graph->type_name() == kKernelGraphClassName;
bool in_kg = graph->type_name() == kKernelGraphTypeName;
if (in_kg && IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
return true;
}

View File

@@ -804,7 +804,7 @@ void FuncGraph::set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forwa
std::vector<AnfNodePtr> FuncGraph::TopoSort(const AnfNodePtr &node) { return mindspore::TopoSort(node); }
const std::vector<AnfNodePtr> &FuncGraph::inputs() const {
if (type_name() == kFuncGraphClassName) {
if (type_name() == kFuncGraphTypeName) {
return parameters();
} else {
return *inputs_;

View File

@@ -99,8 +99,8 @@ const char kFuncGraphFlagBackPropEntry[] = "back_prop_entry";
const char kFuncGraphFlagReAutoMonad[] = "re_auto_monad";
const char kFuncGraphFlagRecursive[] = "recursive";
const char kFuncGraphClassName[] = "FuncGraph";
const char kKernelGraphClassName[] = "KernelGraph";
const char kFuncGraphTypeName[] = "FuncGraph";
const char kKernelGraphTypeName[] = "KernelGraph";
class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
public:

View File

@@ -30,15 +30,6 @@ namespace mindspore {
namespace {
constexpr auto kProviderGe = "ge";
std::string GetOriginFuncGraphName(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
KernelGraphPtr kg = std::dynamic_pointer_cast<session::KernelGraph>(graph);
MS_EXCEPTION_IF_NULL(kg);
FuncGraphPtr origin_graph = kg->GetFuncGraph();
MS_EXCEPTION_IF_NULL(origin_graph);
return origin_graph->ToString();
}
void GetMeRetDataType(const AbstractBasePtr &cnode_data, std::vector<TypeId> *me_types) {
MS_EXCEPTION_IF_NULL(cnode_data);
@@ -80,26 +71,6 @@ transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
return res;
}
void ReorderInputsAsFrontGraph(const KernelGraphPtr &kernel_graph, const FuncGraphPtr &origin_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
const auto &front_map = kernel_graph->front_backend_anf_map();
const auto &origin_parameters = origin_graph->get_inputs();
std::vector<AnfNodePtr> new_parameters;
for (const auto &param : origin_parameters) {
auto iter = front_map.find(param);
if (iter == front_map.end()) {
MS_LOG(EXCEPTION) << "Invalid kernel graph " << kernel_graph->ToString() << " cannot find parameters "
<< param->DebugString();
}
new_parameters.push_back(iter->second);
}
kernel_graph->set_parameters(new_parameters);
kernel_graph->SetGraphInputs(new_parameters);
kernel_graph->SetInputNodes();
}
bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) {
MS_EXCEPTION_IF_NULL(anf_graph);
auto converter = transform::NewConverter(anf_graph);
@@ -243,22 +214,16 @@ bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map<str
MS_LOG(ERROR) << "Input param graph is nullptr.";
return false;
}
KernelGraphPtr kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
if (kernel_graph == nullptr) {
KernelGraphPtr kg = std::dynamic_pointer_cast<session::KernelGraph>(graph);
if (kg == nullptr) {
MS_LOG(ERROR) << "Dynamic cast kernel graph failed.";
return false;
}
FuncGraphPtr origin_graph = kernel_graph->GetFuncGraph();
if (origin_graph == nullptr) {
MS_LOG(ERROR) << "Origin graph of kernel failed.";
return false;
}
ReorderInputsAsFrontGraph(kernel_graph, origin_graph);
// opt::GeOptimization(origin_graph);
(void)BuildDFGraph(origin_graph, GetParams(origin_graph), false);
(void)BuildDFGraph(kg, GetParams(kg), false);
kernel_graph->set_run_mode(device::RunMode::kGraphMode);
// copy init weight to device
RunGeInitGraph(origin_graph);
RunGeInitGraph(kg);
return true;
}
@@ -269,7 +234,8 @@ bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tens
MS_LOG(ERROR) << " Input param is nullptr.";
return false;
}
MS_LOG(INFO) << "GE run graph " << graph->ToString() << " start.";
auto graph_name = graph->ToString();
MS_LOG(INFO) << "GE run graph " << graph_name << " start.";
std::vector<tensor::TensorPtr> input_tensors;
for (const auto &input : inputs) {
auto tensor = std::make_shared<tensor::Tensor>(input);
@ -279,7 +245,7 @@ bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tens
// call ge rungraph
transform::RunOptions run_options;
run_options.name = GetOriginFuncGraphName(graph);
run_options.name = graph_name;
auto graph_runner = transform::GetGraphRunner();
if (graph_runner == nullptr) {
MS_LOG(EXCEPTION) << "Can not found GraphRunner.";

View File

@@ -80,18 +80,21 @@ KernelGraphPtr KernelGraphUtils::ConstructKernelGraph(const FuncGraphPtr &func_g
front_backend_graph_map_[func_graph.get()] = graph;
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
graph->set_device_target(device_target);
// Create parameter
for (const auto &node : func_graph->parameters()) {
MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto new_parameter = CreateNewParameter(node, graph.get());
graph_inputs->push_back(new_parameter);
graph->FrontBackendMapAdd(node, new_parameter);
}
for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
// Create parameter
if (node->isa<Parameter>()) {
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto new_parameter = CreateNewParameter(node, graph.get());
graph_inputs->push_back(new_parameter);
graph->FrontBackendMapAdd(node, new_parameter);
continue;
}
MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
// Create value node
if (node->isa<ValueNode>()) {
// Create common value node