forked from mindspore-Ecosystem/mindspore
actor runtime support GraphKrenel
This commit is contained in:
parent
f9d5a813e2
commit
7f634d12f0
|
@ -297,6 +297,58 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) {
|
||||
std::vector<KernelWithIndex> ret;
|
||||
std::vector<KernelWithIndex> ret_empty;
|
||||
|
||||
// The MakeTuple node need expand and recurse.
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
|
||||
auto input_i_vector = GetAllOutputWithIndex(make_tuple->input(i));
|
||||
(void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto outputs_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
if (!IsRealCNodeKernel(node)) {
|
||||
outputs_num = 1;
|
||||
}
|
||||
// The output may be the tuple, so need visit all the outputs of node.
|
||||
for (size_t i = 0; i < outputs_num; ++i) {
|
||||
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(node, i, false);
|
||||
MS_EXCEPTION_IF_NULL(output_with_index.first);
|
||||
|
||||
// The MakeTuple node need recurse.
|
||||
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) {
|
||||
auto input_vector = GetAllOutputWithIndex(output_with_index.first);
|
||||
(void)std::copy(input_vector.begin(), input_vector.end(), std::back_inserter(ret));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ignore the output of front call node.
|
||||
if (output_with_index.first->isa<CNode>()) {
|
||||
auto cnode = output_with_index.first->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
if (inputs[0]->isa<CNode>()) {
|
||||
MS_LOG(INFO) << "The output is call node: " << output_with_index.first->DebugString();
|
||||
return ret_empty;
|
||||
}
|
||||
}
|
||||
|
||||
// The InitDataSetQueue node has no output.
|
||||
if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimInitDataSetQueue)) {
|
||||
return ret_empty;
|
||||
}
|
||||
|
||||
ret.push_back(output_with_index);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return node->input(kAnfPrimitiveIndex);
|
||||
|
|
|
@ -72,6 +72,7 @@ class AnfRuntimeAlgorithm {
|
|||
prim::kPrimMakeTuple});
|
||||
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
|
||||
const std::vector<PrimitivePtr> &return_types = {});
|
||||
static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
|
||||
// get cnode primitive
|
||||
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
|
||||
static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
|
||||
|
|
|
@ -1106,7 +1106,7 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
|
|||
MS_LOG(INFO) << "Node is not internal output";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
|
||||
MS_LOG(INFO) << "Replace internal output node " << node->DebugString() << " to " << new_node->DebugString();
|
||||
auto &front_nodes = iter->second;
|
||||
// Move specified front node to new node mapping
|
||||
auto front_node_iter = front_nodes.find(src_output_idx);
|
||||
|
@ -1139,6 +1139,85 @@ AnfWithOutIndex KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &p
|
|||
return AnfWithOutIndex();
|
||||
}
|
||||
|
||||
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output,
|
||||
const AnfNodePtr &front_node) {
|
||||
if ((backend_graph_output == nullptr) || (front_node == nullptr)) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_graph_output);
|
||||
auto front_outputs = AnfAlgo::GetAllOutputWithIndex(front_node);
|
||||
if (backend_outputs.size() != front_outputs.size()) {
|
||||
MS_LOG(INFO) << "The size(" << backend_outputs.size()
|
||||
<< ") of backend output: " << backend_graph_output->DebugString() << " is not equal to the size("
|
||||
<< front_outputs.size() << ") of front output: " << front_node->DebugString();
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < backend_outputs.size(); ++i) {
|
||||
auto backend_output = backend_outputs[i];
|
||||
auto front_output = front_outputs[i];
|
||||
graph_output_to_front_node_map_[backend_output] = front_output;
|
||||
MS_LOG(INFO) << "Backend output: " << backend_output.first->fullname_with_scope()
|
||||
<< " with index: " << backend_output.second
|
||||
<< " map to front node: " << front_output.first->fullname_with_scope()
|
||||
<< " with index: " << front_output.second;
|
||||
}
|
||||
}
|
||||
|
||||
AnfWithOutIndex KernelGraph::GetFrontNodeWithIndexByGraphOutput(
|
||||
const AnfWithOutIndex &backend_graph_output_with_index) const {
|
||||
const auto &iter = graph_output_to_front_node_map_.find(backend_graph_output_with_index);
|
||||
if (iter != graph_output_to_front_node_map_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
return AnfWithOutIndex();
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateGraphOutputMap(const std::vector<AnfWithOutIndex> &old_outputs,
|
||||
const std::vector<AnfWithOutIndex> &new_outputs) {
|
||||
MS_LOG(INFO) << "The size of old outputs: " << old_outputs.size()
|
||||
<< ", the size of new outputs: " << new_outputs.size();
|
||||
if (old_outputs.size() != new_outputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of old outputs is not equal to the size of new outputs.";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < old_outputs.size(); ++i) {
|
||||
auto old_output = old_outputs[i];
|
||||
auto new_output = new_outputs[i];
|
||||
if (old_output == new_output) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Update the graph output map.
|
||||
if (graph_output_to_front_node_map_.count(old_output) > 0) {
|
||||
MS_LOG(INFO) << "Replace backend output node " << old_output.first->fullname_with_scope() << " with index "
|
||||
<< old_output.second << " to " << new_output.first->fullname_with_scope() << " with index "
|
||||
<< new_output.second;
|
||||
graph_output_to_front_node_map_[new_output] = graph_output_to_front_node_map_[old_output];
|
||||
graph_output_to_front_node_map_.erase(old_output);
|
||||
}
|
||||
|
||||
// Update the internal output map.
|
||||
if (IsInternalOutput(old_output.first, old_output.second)) {
|
||||
ReplaceInternalOutput(old_output.first, new_output.first, old_output.second, new_output.second);
|
||||
}
|
||||
|
||||
if (old_output.first == new_output.first) {
|
||||
continue;
|
||||
}
|
||||
// Update the front backend node map.
|
||||
if (backend_front_anf_map_.count(old_output.first) > 0) {
|
||||
MS_LOG(INFO) << "Replace backend output node " << old_output.first->fullname_with_scope() << " to "
|
||||
<< new_output.first->fullname_with_scope();
|
||||
auto front_node = backend_front_anf_map_[old_output.first];
|
||||
front_backend_anf_map_[front_node] = new_output.first;
|
||||
backend_front_anf_map_[new_output.first] = front_node;
|
||||
(void)backend_front_anf_map_.erase(old_output.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
|
||||
auto iter = front_to_internal_outputs_map_.find(front_node);
|
||||
if (iter != front_to_internal_outputs_map_.end()) {
|
||||
|
|
|
@ -72,6 +72,7 @@ class KernelGraph : public FuncGraph {
|
|||
start_label_ = graph.start_label_;
|
||||
end_goto_ = graph.end_goto_;
|
||||
internal_parameter_to_front_node_map_ = graph.internal_parameter_to_front_node_map_;
|
||||
graph_output_to_front_node_map_ = graph.graph_output_to_front_node_map_;
|
||||
front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
|
||||
internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
|
||||
internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
|
||||
|
@ -206,9 +207,6 @@ class KernelGraph : public FuncGraph {
|
|||
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx,
|
||||
size_t dst_output_idx);
|
||||
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
|
||||
// Cache the internal parameter and corresponding to front node into internal_parameter_to_front_node_map_.
|
||||
void CacheInternalParameterToFrontNode(const AnfNodePtr ¶meter, const AnfWithOutIndex &front_node_with_index);
|
||||
AnfWithOutIndex GetFrontNodeByInternalParameter(const AnfNodePtr ¶meter) const;
|
||||
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
|
||||
bool IsInternalOutput(const AnfNodePtr &node, size_t output_idx) const;
|
||||
bool IsInternalOutput(const AnfNodePtr &node) const;
|
||||
|
@ -216,6 +214,18 @@ class KernelGraph : public FuncGraph {
|
|||
void AddInternalOutputTensor(const AnfNodePtr &node, size_t output_idx, const tensor::TensorPtr &tensor);
|
||||
tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, size_t output_idx);
|
||||
|
||||
// Cache the internal parameter and corresponding to front node into internal_parameter_to_front_node_map_.
|
||||
void CacheInternalParameterToFrontNode(const AnfNodePtr ¶meter, const AnfWithOutIndex &front_node_with_index);
|
||||
AnfWithOutIndex GetFrontNodeByInternalParameter(const AnfNodePtr ¶meter) const;
|
||||
|
||||
// Cache the backend graph output nodes and corresponding to front nodes with output index into
|
||||
// graph_output_to_front_node_map_.
|
||||
void CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output, const AnfNodePtr &front_node);
|
||||
AnfWithOutIndex GetFrontNodeWithIndexByGraphOutput(const AnfWithOutIndex &backend_graph_output_with_index) const;
|
||||
// Update the related map of backend graph output nodes by modified backend output nodes.
|
||||
void UpdateGraphOutputMap(const std::vector<AnfWithOutIndex> &old_outputs,
|
||||
const std::vector<AnfWithOutIndex> &new_outputs);
|
||||
|
||||
uint32_t current_epoch() const { return current_epoch_; }
|
||||
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
|
||||
void UpdateChildGraphOrder();
|
||||
|
@ -376,10 +386,15 @@ class KernelGraph : public FuncGraph {
|
|||
|
||||
CNodePtr start_label_;
|
||||
CNodePtr end_goto_;
|
||||
|
||||
// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
|
||||
// related to the input of this kernel graph. The first of unordered map is the input of this kernel graph, the second
|
||||
// of unordered map is front node corresponding to the output of previous kernel graph.
|
||||
std::unordered_map<AnfNodePtr, AnfWithOutIndex> internal_parameter_to_front_node_map_;
|
||||
// The first of map is the backend graph output of this kernel graph, the second of map is front node corresponding to
|
||||
// the backend node with index.
|
||||
std::map<AnfWithOutIndex, AnfWithOutIndex> graph_output_to_front_node_map_;
|
||||
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
|
||||
std::unordered_map<AnfNodePtr, std::unordered_map<size_t, std::pair<AnfNodePtr, bool>>>
|
||||
internal_outputs_to_front_map_;
|
||||
|
|
|
@ -1919,6 +1919,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
|
|||
auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr {
|
||||
auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
|
||||
if (backend_anf != nullptr) {
|
||||
graph->CacheGraphOutputToFrontNodeWithIndex(backend_anf, out);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
|
|
|
@ -266,7 +266,11 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
|||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
||||
// Execute optimization pass.
|
||||
auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||
device_context->OptimizeGraph(graph);
|
||||
auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||
// Update the output map of kernel graph by modified output nodes.
|
||||
graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);
|
||||
|
||||
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
|
||||
// 'KernelMod' is real executive object of kernel.
|
||||
|
|
|
@ -74,6 +74,17 @@ AnfNodePtr FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const Ker
|
|||
return front_node;
|
||||
}
|
||||
|
||||
KernelWithIndex FetchFrontNodeWithIndexByGraphOutput(const KernelWithIndex &output_with_index,
|
||||
const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto front_node_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
|
||||
// PyNative forward graph does not has front node, using backend node instead.
|
||||
if (front_node_with_index.first == nullptr) {
|
||||
front_node_with_index = output_with_index;
|
||||
}
|
||||
return front_node_with_index;
|
||||
}
|
||||
|
||||
// The branch processing of PrepareDataForValueNode that value type is tensor.
|
||||
void PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
|
||||
const DeviceContext *device_context) {
|
||||
|
@ -649,18 +660,16 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
|
|||
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||
for (const auto &graph : graph_compiler_info.graphs_) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const auto &outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||
for (const auto &output : outputs) {
|
||||
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, false);
|
||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||
for (const auto &output_with_index : outputs) {
|
||||
auto output_kernel = output_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(output_kernel);
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(output_kernel);
|
||||
if (front_node == nullptr) {
|
||||
auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
|
||||
if (origin_output_with_index.first == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto actor_output_index = output_with_index.second;
|
||||
auto origin_output_with_index = KernelWithIndex(front_node, actor_output_index);
|
||||
OpActor<DeviceTensor> *actor = nullptr;
|
||||
if (IsKernelActor(output_kernel)) {
|
||||
actor = FetchActor(output_kernel->fullname_with_scope());
|
||||
|
@ -684,7 +693,8 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
|
|||
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
|
||||
<< " to actor:" << actor->GetAID().Name() << " with output index:" << actor_output_index;
|
||||
<< " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name()
|
||||
<< " with index:" << actor_output_index;
|
||||
graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(actor, actor_output_index));
|
||||
}
|
||||
}
|
||||
|
@ -1302,20 +1312,22 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
|
|||
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
|
||||
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
|
||||
for (const auto &real_depend_input : real_depend_inputs) {
|
||||
auto real_depend_input_with_idx = AnfAlgo::VisitKernelWithReturnType(real_depend_input, 0, false, return_types);
|
||||
auto real_depend_kernel = real_depend_input_with_idx.first;
|
||||
// The monad node and make tuple node need recursion.
|
||||
if (AnfAlgo::IsOneOfPrimitiveCNode(real_depend_input, recursion_prims)) {
|
||||
LinkControlArrowByAutoMonad(to_actor, real_depend_input);
|
||||
if (AnfAlgo::IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
|
||||
LinkControlArrowByAutoMonad(to_actor, real_depend_kernel);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!IsKernelActor(real_depend_input)) {
|
||||
if (!IsKernelActor(real_depend_kernel)) {
|
||||
continue;
|
||||
}
|
||||
// Link the control arrow between the kernel actors.
|
||||
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_input->fullname_with_scope()));
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
MS_LOG(INFO) << "Link control arrow by auto monad, from actor: " << from_actor->GetAID().Name()
|
||||
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_kernel->fullname_with_scope()));
|
||||
MS_LOG(INFO) << "Link control arrow by auto monad, from actor: " << real_depend_kernel->fullname_with_scope()
|
||||
<< ", to actor: " << to_actor->GetAID().Name();
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
|
||||
to_actor->input_controls_num_++;
|
||||
}
|
||||
|
@ -1427,12 +1439,10 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
|||
for (const auto &graph : graph_compiler_info.graphs_) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
++number;
|
||||
const auto &outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||
for (const auto &output : outputs) {
|
||||
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, false);
|
||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||
for (const auto &output_with_index : outputs) {
|
||||
MS_EXCEPTION_IF_NULL(output_with_index.first);
|
||||
const auto &front_node = FetchFrontNodeByBackendNode(output_with_index.first, graph);
|
||||
auto origin_output_with_index = KernelWithIndex(front_node, output_with_index.second);
|
||||
auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph);
|
||||
const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
|
||||
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
|
||||
continue;
|
||||
|
@ -1618,12 +1628,12 @@ void GraphScheduler::LinkDataArrowByCallInput(const GraphCompilerInfo &graph_com
|
|||
// Collect the output of each funcgraph.
|
||||
for (const auto &func_graph : func_graphs) {
|
||||
// The output of funcgraph can only have one.
|
||||
const auto &outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem});
|
||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(func_graph->output());
|
||||
if (outputs.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Output of func graph is more than one, func graph:" << func_graph->ToString();
|
||||
}
|
||||
|
||||
auto output_with_index = AnfAlgo::VisitKernelWithReturnType(outputs[0], 0);
|
||||
auto output_with_index = outputs[0];
|
||||
if (IsKernelActor(output_with_index.first)) {
|
||||
// Input is a kernel actor.
|
||||
const auto &iter = front_node_to_actor_.find(output_with_index.first);
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "utils/trace_base.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/common/pass_manager.h"
|
||||
#include "backend/optimizer/common/common_backend_optimization.h"
|
||||
#include "backend/optimizer/cpu/insert_cast_cpu.h"
|
||||
#include "backend/optimizer/cpu/insert_format_transform_op.h"
|
||||
#include "backend/optimizer/pass/replace_node_by_proxy.h"
|
||||
|
@ -75,6 +76,9 @@ void CPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
|
|||
SetOperatorInfo(graph->execution_order());
|
||||
OptimizeGraphImpl(graph);
|
||||
|
||||
// Run final optimization.
|
||||
opt::CommonFinalOptimization(graph);
|
||||
|
||||
// Remove reorder after PS feature finish adapting push/pull in auto_monad.
|
||||
auto execution_order = graph->execution_order();
|
||||
AnfAlgo::ReorderPosteriorExecList(NOT_NULL(&execution_order));
|
||||
|
|
|
@ -184,11 +184,15 @@ void GPUDeviceContext::OptimizeGraph(const KernelGraphPtr &graph) const {
|
|||
// Optimization pass which is irrelevant to device type or format.
|
||||
OptimizeGraphWithoutDeviceInfo(graph);
|
||||
|
||||
FormatTransformChecker::GetInstance().CheckSupportFormatTransform(graph);
|
||||
SetOperatorInfo(graph->execution_order());
|
||||
|
||||
// Optimization pass which is relevant to device type or format.
|
||||
OptimizeGraphWithDeviceInfo(graph);
|
||||
|
||||
// Run final optimization.
|
||||
opt::CommonFinalOptimization(graph);
|
||||
|
||||
// Graph kernel fusion optimization
|
||||
if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
opt::GraphKernelOptimize(graph);
|
||||
|
@ -270,6 +274,7 @@ void GPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr>
|
|||
|
||||
void GPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
FormatTransformChecker::GetInstance().CheckSupportFormatTransform(graph);
|
||||
SetOperatorInfo(graph->execution_order());
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
|
|
|
@ -514,30 +514,10 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
|
|||
const auto &all_branch_output = ControlNodeParser::FetchAllBranchOutputs(root_graph);
|
||||
for (const auto &branch_output : all_branch_output) {
|
||||
size_t position = 0;
|
||||
if (AnfAlgo::CheckPrimitiveType(branch_output, prim::kPrimMakeTuple)) {
|
||||
const auto &outputs = AnfAlgo::GetAllOutput(branch_output, {prim::kPrimTupleGetItem});
|
||||
outputs_num = outputs.size();
|
||||
|
||||
for (const auto &output : outputs) {
|
||||
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(output_with_index.first);
|
||||
// The InitDataSetQueue kernel has no output.
|
||||
if (AnfAlgo::GetCNodeName(output_with_index.first) == kInitDatasetQueueOpName) {
|
||||
continue;
|
||||
}
|
||||
outputs_order.emplace(output_with_index, position++);
|
||||
}
|
||||
} else if (branch_output->isa<CNode>()) {
|
||||
outputs_num = AnfAlgo::GetOutputTensorNum(branch_output);
|
||||
for (size_t i = 0; i < outputs_num; i++) {
|
||||
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(branch_output, i, false);
|
||||
MS_EXCEPTION_IF_NULL(output_with_index.first);
|
||||
// The InitDataSetQueue kernel has no output.
|
||||
if (AnfAlgo::GetCNodeName(output_with_index.first) == kInitDatasetQueueOpName) {
|
||||
continue;
|
||||
}
|
||||
outputs_order.emplace(output_with_index, position++);
|
||||
}
|
||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(branch_output);
|
||||
outputs_num = outputs.size();
|
||||
for (const auto &output : outputs) {
|
||||
outputs_order.emplace(output, position++);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -580,17 +560,12 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
|
|||
device_contexts.emplace_back(graph_info_to_context.second);
|
||||
name.append(graph_info_to_context.first);
|
||||
|
||||
const auto &outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||
for (const auto &output : outputs) {
|
||||
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(output_with_index.first);
|
||||
// The InitDataSetQueue kernel has no output.
|
||||
if (AnfAlgo::GetCNodeName(output_with_index.first) == kInitDatasetQueueOpName) {
|
||||
continue;
|
||||
}
|
||||
outputs_order.emplace(output_with_index, position++);
|
||||
outputs_order.emplace(output, position++);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(tensors_mask));
|
||||
std::vector<std::vector<TensorPtr> *> input_tensors_list(1,
|
||||
const_cast<std::vector<tensor::TensorPtr> *>(input_tensors));
|
||||
|
|
|
@ -372,6 +372,7 @@ inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy
|
|||
inline const PrimitivePtr kPrimFill = std::make_shared<Primitive>("Fill");
|
||||
inline const PrimitivePtr kPrimFusedPushWeight = std::make_shared<Primitive>("FusedPushWeight");
|
||||
inline const PrimitivePtr kPrimFusedPullWeight = std::make_shared<Primitive>("FusedPullWeight");
|
||||
inline const PrimitivePtr kPrimInitDataSetQueue = std::make_shared<Primitive>("InitDataSetQueue");
|
||||
|
||||
// Quant ops
|
||||
inline const PrimitivePtr kPrimBatchNormFold = std::make_shared<Primitive>("BatchNormFold");
|
||||
|
|
Loading…
Reference in New Issue