!44855 modify code in lite to adapt to new ge_adapt
Merge pull request !44855 from xiaoyao/ge_mindrt
This commit is contained in:
commit
85800a67d7
|
@ -939,7 +939,7 @@ std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGrap
|
|||
// Create parameter
|
||||
for (const auto &node : func_graph->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
||||
MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
auto new_parameter = CreateNewParameter(node, graph.get());
|
||||
|
@ -951,7 +951,7 @@ std::shared_ptr<KernelGraph> KernelGraphMgr::ConstructKernelGraph(const FuncGrap
|
|||
if (node->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
||||
MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
|
||||
// Create value node
|
||||
if (node->isa<ValueNode>()) {
|
||||
// Create common value node
|
||||
|
|
|
@ -86,7 +86,7 @@ class DfGraphConvertor {
|
|||
} else {
|
||||
ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE);
|
||||
}
|
||||
is_kernel_graph_ = anf_graph_->type_name() == kKernelGraphClassName;
|
||||
is_kernel_graph_ = anf_graph_->type_name() == kKernelGraphTypeName;
|
||||
df_graph_ = std::make_shared<DfGraph>(anf_graph_->ToString());
|
||||
std::string graph_type = is_kernel_graph_ ? "kernel_graph" : "func_graph";
|
||||
std::string graph_name = anf_graph_->ToString();
|
||||
|
|
|
@ -105,7 +105,7 @@ bool IsWhileNode(const AnfNodePtr &node) {
|
|||
}
|
||||
auto graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
bool in_kg = graph->type_name() == kKernelGraphClassName;
|
||||
bool in_kg = graph->type_name() == kKernelGraphTypeName;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
ValueNodePtr graph_node = nullptr;
|
||||
if (in_kg && IsPrimitiveCNode(node, prim::kPrimCall) && cnode->input(1)->isa<ValueNode>()) {
|
||||
|
@ -176,7 +176,7 @@ bool IsIfNode(const AnfNodePtr &node) {
|
|||
}
|
||||
auto graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
bool in_kg = graph->type_name() == kKernelGraphClassName;
|
||||
bool in_kg = graph->type_name() == kKernelGraphTypeName;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
CNodePtr switch_node = nullptr;
|
||||
|
@ -222,7 +222,7 @@ bool IsCaseNode(const AnfNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
bool in_kg = graph->type_name() == kKernelGraphClassName;
|
||||
bool in_kg = graph->type_name() == kKernelGraphTypeName;
|
||||
if (in_kg && IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -804,7 +804,7 @@ void FuncGraph::set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forwa
|
|||
std::vector<AnfNodePtr> FuncGraph::TopoSort(const AnfNodePtr &node) { return mindspore::TopoSort(node); }
|
||||
|
||||
const std::vector<AnfNodePtr> &FuncGraph::inputs() const {
|
||||
if (type_name() == kFuncGraphClassName) {
|
||||
if (type_name() == kFuncGraphTypeName) {
|
||||
return parameters();
|
||||
} else {
|
||||
return *inputs_;
|
||||
|
|
|
@ -99,8 +99,8 @@ const char kFuncGraphFlagBackPropEntry[] = "back_prop_entry";
|
|||
const char kFuncGraphFlagReAutoMonad[] = "re_auto_monad";
|
||||
const char kFuncGraphFlagRecursive[] = "recursive";
|
||||
|
||||
const char kFuncGraphClassName[] = "FuncGraph";
|
||||
const char kKernelGraphClassName[] = "KernelGraph";
|
||||
const char kFuncGraphTypeName[] = "FuncGraph";
|
||||
const char kKernelGraphTypeName[] = "KernelGraph";
|
||||
|
||||
class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
||||
public:
|
||||
|
|
|
@ -30,15 +30,6 @@ namespace mindspore {
|
|||
namespace {
|
||||
constexpr auto kProviderGe = "ge";
|
||||
|
||||
std::string GetOriginFuncGraphName(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
KernelGraphPtr kg = std::dynamic_pointer_cast<session::KernelGraph>(graph);
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
FuncGraphPtr origin_graph = kg->GetFuncGraph();
|
||||
MS_EXCEPTION_IF_NULL(origin_graph);
|
||||
return origin_graph->ToString();
|
||||
}
|
||||
|
||||
void GetMeRetDataType(const AbstractBasePtr &cnode_data, std::vector<TypeId> *me_types) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_data);
|
||||
|
||||
|
@ -80,26 +71,6 @@ transform::TensorOrderMap GetParams(const FuncGraphPtr &anf_graph) {
|
|||
return res;
|
||||
}
|
||||
|
||||
void ReorderInputsAsFrontGraph(const KernelGraphPtr &kernel_graph, const FuncGraphPtr &origin_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
const auto &front_map = kernel_graph->front_backend_anf_map();
|
||||
const auto &origin_parameters = origin_graph->get_inputs();
|
||||
std::vector<AnfNodePtr> new_parameters;
|
||||
|
||||
for (const auto ¶m : origin_parameters) {
|
||||
auto iter = front_map.find(param);
|
||||
if (iter == front_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid kernel graph " << kernel_graph->ToString() << " cannot find parameters "
|
||||
<< param->DebugString();
|
||||
}
|
||||
new_parameters.push_back(iter->second);
|
||||
}
|
||||
|
||||
kernel_graph->set_parameters(new_parameters);
|
||||
kernel_graph->SetGraphInputs(new_parameters);
|
||||
kernel_graph->SetInputNodes();
|
||||
}
|
||||
|
||||
bool AddDFGraph(const FuncGraphPtr &anf_graph, const transform::TensorOrderMap &init_inputs_map, bool export_air) {
|
||||
MS_EXCEPTION_IF_NULL(anf_graph);
|
||||
auto converter = transform::NewConverter(anf_graph);
|
||||
|
@ -243,22 +214,16 @@ bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map<str
|
|||
MS_LOG(ERROR) << "Input param graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
KernelGraphPtr kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
|
||||
if (kernel_graph == nullptr) {
|
||||
KernelGraphPtr kg = std::dynamic_pointer_cast<session::KernelGraph>(graph);
|
||||
if (kg == nullptr) {
|
||||
MS_LOG(ERROR) << "Dynamic cast kernel graph failed.";
|
||||
return false;
|
||||
}
|
||||
FuncGraphPtr origin_graph = kernel_graph->GetFuncGraph();
|
||||
if (origin_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Origin graph of kernel failed.";
|
||||
return false;
|
||||
}
|
||||
ReorderInputsAsFrontGraph(kernel_graph, origin_graph);
|
||||
// opt::GeOptimization(origin_graph);
|
||||
(void)BuildDFGraph(origin_graph, GetParams(origin_graph), false);
|
||||
(void)BuildDFGraph(kg, GetParams(kg), false);
|
||||
kernel_graph->set_run_mode(device::RunMode::kGraphMode);
|
||||
// copy init weight to device
|
||||
RunGeInitGraph(origin_graph);
|
||||
RunGeInitGraph(kg);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -269,7 +234,8 @@ bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tens
|
|||
MS_LOG(ERROR) << " Input param is nullptr.";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "GE run graph " << graph->ToString() << " start.";
|
||||
auto graph_name = graph->ToString();
|
||||
MS_LOG(INFO) << "GE run graph " << graph_name << " start.";
|
||||
std::vector<tensor::TensorPtr> input_tensors;
|
||||
for (const auto &input : inputs) {
|
||||
auto tensor = std::make_shared<tensor::Tensor>(input);
|
||||
|
@ -279,7 +245,7 @@ bool GeGraphExecutor::RunGraph(const FuncGraphPtr &graph, const std::vector<tens
|
|||
|
||||
// call ge rungraph
|
||||
transform::RunOptions run_options;
|
||||
run_options.name = GetOriginFuncGraphName(graph);
|
||||
run_options.name = graph_name;
|
||||
auto graph_runner = transform::GetGraphRunner();
|
||||
if (graph_runner == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
|
||||
|
|
|
@ -80,18 +80,21 @@ KernelGraphPtr KernelGraphUtils::ConstructKernelGraph(const FuncGraphPtr &func_g
|
|||
front_backend_graph_map_[func_graph.get()] = graph;
|
||||
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
|
||||
graph->set_device_target(device_target);
|
||||
// Create parameter
|
||||
for (const auto &node : func_graph->parameters()) {
|
||||
MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
auto new_parameter = CreateNewParameter(node, graph.get());
|
||||
graph_inputs->push_back(new_parameter);
|
||||
graph->FrontBackendMapAdd(node, new_parameter);
|
||||
}
|
||||
for (const auto &node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
|
||||
// Create parameter
|
||||
if (node->isa<Parameter>()) {
|
||||
auto graph_inputs = graph->MutableInputs();
|
||||
MS_EXCEPTION_IF_NULL(graph_inputs);
|
||||
auto new_parameter = CreateNewParameter(node, graph.get());
|
||||
graph_inputs->push_back(new_parameter);
|
||||
graph->FrontBackendMapAdd(node, new_parameter);
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Start create new node, node = " << node->DebugString();
|
||||
// Create value node
|
||||
if (node->isa<ValueNode>()) {
|
||||
// Create common value node
|
||||
|
|
Loading…
Reference in New Issue