forked from OSSInnovation/mindspore
add graph compiler
This commit is contained in:
parent
d2ee376b7b
commit
c57a9edb9d
|
@ -128,6 +128,13 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
// Get graph by graph id, if not exist return null ptr
|
||||
KernelGraphPtr GetGraph(GraphId graph_id) const;
|
||||
void ClearGraph();
|
||||
// create a single run op graph
|
||||
std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
|
||||
void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
|
||||
void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
|
||||
void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
// set debugger
|
||||
void SetDebugger() {
|
||||
|
@ -163,12 +170,12 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
VectorRef *outputs,
|
||||
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
|
||||
virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0;
|
||||
virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
|
||||
virtual void UnifyMindIR(const KernelGraphPtr &graph) {}
|
||||
virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
|
||||
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
|
||||
virtual void BuildGraphImpl(GraphId) {}
|
||||
virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) = 0;
|
||||
virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
|
||||
}
|
||||
virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) {}
|
||||
|
@ -183,7 +190,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
|
||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
|
||||
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) const;
|
||||
void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const;
|
||||
|
@ -191,10 +197,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
// create graph output for RunOp
|
||||
void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
|
||||
CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
|
||||
// create a single run op graph
|
||||
std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
|
||||
// Generate graph info for a single op graph
|
||||
GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors);
|
||||
void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info);
|
||||
|
@ -219,8 +221,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
|
||||
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
|
||||
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
|
||||
void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
|
||||
void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
|
||||
virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
|
||||
void InitAllBucket(const KernelGraphPtr &graph);
|
||||
void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"){}
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/framework/graph_compiler.h"
|
||||
#include "runtime/framework/graph_scheduler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void GraphCompiler::set_device_context(device::DeviceContext *device_context) {
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
device_context_ = device_context;
|
||||
|
||||
// The member variable 'session_' will be removed after removing session module.
|
||||
if (session_ == nullptr) {
|
||||
session_ = std::make_shared<session::SessionBasic>();
|
||||
}
|
||||
}
|
||||
|
||||
GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
// Generate kernel graph.
|
||||
auto graph = session_->ConstructKernelGraph(nodes, outputs);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
return CompileGraphImpl(graph);
|
||||
}
|
||||
|
||||
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
// Optimization pass which is irrelevant to device type or format.
|
||||
device_context_->OptimizeGraphWithoutDeviceInfo(graph);
|
||||
|
||||
device_context_->SetOperatorInfo(graph->execution_order());
|
||||
|
||||
// Optimization pass which is relevant to device type or format.
|
||||
device_context_->OptimizeGraphWithDeviceInfo(graph);
|
||||
|
||||
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
|
||||
// 'KernelMod' is real executive object of kernel.
|
||||
device_context_->CreateKernel(graph->execution_order());
|
||||
|
||||
// Transform graph to actor DAG, contains build and link.
|
||||
GraphScheduler::GetInstance().Transform(graph, device_context_);
|
||||
return graph->graph_id();
|
||||
}
|
||||
|
||||
void GraphCompiler::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
auto graph = session_->GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
GraphScheduler::GetInstance().Run(actor_set);
|
||||
}
|
||||
|
||||
void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask, VectorRef *outputs) {
|
||||
// Check if the graph cache exists.
|
||||
if (run_op_graphs_.find(graph_info) == run_op_graphs_.end()) {
|
||||
// Prepare the graph
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
device_context_->SetOperatorInfo(graph->execution_order());
|
||||
|
||||
device_context_->OptimizeSingleOpGraph(graph);
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
session_->RunOpHideNopNode(graph);
|
||||
|
||||
device_context_->CreateKernel(graph->execution_order());
|
||||
run_op_graphs_[graph_info] = graph;
|
||||
}
|
||||
|
||||
session_->EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
||||
// wait for allreduce
|
||||
for (auto &tensor : *input_tensors) {
|
||||
if (tensor->NeedWaitDevice()) {
|
||||
tensor->WaitDevice();
|
||||
}
|
||||
}
|
||||
|
||||
// run op
|
||||
auto graph = run_op_graphs_[graph_info];
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
session_->RunOpRemoveNopNode(graph);
|
||||
|
||||
GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
|
||||
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep);
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -45,7 +45,7 @@ class GraphCompiler {
|
|||
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
|
||||
// Construct single op kernel graph, compile and run the kernel graph in PyNative mode.
|
||||
void CompileAndRunGraph(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
void CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask,
|
||||
VectorRef *outputs);
|
||||
|
||||
|
@ -61,7 +61,7 @@ class GraphCompiler {
|
|||
device::DeviceContext *device_context_{nullptr};
|
||||
|
||||
// Single op kernel graph cache for PyNative mode.
|
||||
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
|
||||
std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_;
|
||||
|
||||
// The member variable 'session_' will be removed after removing session module.
|
||||
session::SessionPtr session_{nullptr};
|
||||
|
|
|
@ -21,6 +21,11 @@
|
|||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "runtime/device/cpu/kernel_select_cpu.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/common/pass_manager.h"
|
||||
#include "backend/optimizer/cpu/insert_cast_cpu.h"
|
||||
#include "backend/optimizer/pass/replace_node_by_proxy.h"
|
||||
#include "backend/optimizer/pass/erase_visit_attr.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -45,6 +50,40 @@ void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const {
|
|||
address->ptr_ = nullptr;
|
||||
}
|
||||
|
||||
void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
|
||||
// Update Graph Dynamic Shape Attr.
|
||||
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
|
||||
|
||||
OptimizeGraphImpl(graph);
|
||||
|
||||
// Remove reorder after PS feature finish adapting push/pull in auto_monad.
|
||||
auto execution_order = graph->execution_order();
|
||||
AnfAlgo::ReorderPosteriorExecList(NOT_NULL(&execution_order));
|
||||
graph->set_execution_order(execution_order);
|
||||
}
|
||||
|
||||
void CPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { OptimizeGraphImpl(graph); }
|
||||
|
||||
void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::InsertCastCPU>());
|
||||
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(graph);
|
||||
graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void CPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const {
|
||||
for (const auto &cnode : graph->execution_order()) {
|
||||
if (AnfAlgo::IsNodeDynamicShape(cnode)) {
|
||||
AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
|
||||
MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
graph->UpdateGraphDynamicAttr();
|
||||
}
|
||||
|
||||
void CPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
|
||||
for (const auto &node : nodes) {
|
||||
SetKernelInfo(node);
|
||||
|
|
|
@ -36,15 +36,23 @@ class CPUDeviceContext : public DeviceContext {
|
|||
bool AllocateMemory(DeviceAddress *const &address, size_t size) const override;
|
||||
void FreeMemory(DeviceAddress *const &address) const override;
|
||||
|
||||
void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override;
|
||||
void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override;
|
||||
|
||||
void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;
|
||||
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
|
||||
|
||||
bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override;
|
||||
|
||||
private:
|
||||
DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);
|
||||
|
||||
// Update Graph Dynamic Shape Attr.
|
||||
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const;
|
||||
|
||||
void OptimizeGraphImpl(const KernelGraphPtr &graph) const;
|
||||
|
||||
uint32_t device_id_;
|
||||
std::shared_ptr<MemoryManager> mem_manager_;
|
||||
bool initialized_;
|
||||
};
|
||||
|
|
|
@ -63,17 +63,23 @@ class DeviceContext {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Optimize the kernel graph according to different devices.
|
||||
virtual void OptimizeGraph(const KernelGraphPtr &graph) const {}
|
||||
// The two functions below will be merged to one in the future.
|
||||
// General graph optimezer ignore device data type and format.
|
||||
virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {}
|
||||
// Optimize the kernel graph according to device data type and format.
|
||||
virtual void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {}
|
||||
|
||||
// Optimize the single operator graph for PyNative mode.
|
||||
virtual void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {}
|
||||
|
||||
// Select the matching backend kernels according to the data type and format of input and output for all
|
||||
// execution operators, and set final device data type and format information for backend kernels, device
|
||||
// data type and format which replace original data type and format will use for executing kernels.
|
||||
virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {}
|
||||
virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const = 0;
|
||||
|
||||
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
|
||||
// 'KernelMod' is real executive object of kernel.
|
||||
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}
|
||||
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const = 0;
|
||||
|
||||
// Launch a kernel via 'KernelMod' of the kernel.
|
||||
virtual bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
|
||||
|
|
|
@ -34,7 +34,7 @@ void DeviceContextManager::ClearDeviceContexts() {
|
|||
device_contexts_.clear();
|
||||
}
|
||||
|
||||
DeviceContext *DeviceContextManager::GetDeviceContext(const DeviceContextKey &device_context_key) {
|
||||
DeviceContext *DeviceContextManager::CreateOrGetDeviceContext(const DeviceContextKey &device_context_key) {
|
||||
std::string device_context_key_str = device_context_key.ToString();
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class DeviceContextManager {
|
|||
return instance;
|
||||
}
|
||||
void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
|
||||
DeviceContext *GetDeviceContext(const DeviceContextKey &device_info);
|
||||
DeviceContext *CreateOrGetDeviceContext(const DeviceContextKey &device_context_key);
|
||||
void ClearDeviceContexts();
|
||||
|
||||
private:
|
||||
|
|
|
@ -27,16 +27,31 @@
|
|||
#include "runtime/device/gpu/gpu_buffer_mgr.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "runtime/device/gpu/gpu_common.h"
|
||||
#include "runtime/hardware/gpu/optimizer.h"
|
||||
#include "common/trans.h"
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace gpu {
|
||||
bool GPUDeviceContext::Initialize() {
|
||||
if (initialized_ == true) {
|
||||
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_context_key_.device_id_)),
|
||||
"Failed to set device id");
|
||||
GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Set device id
|
||||
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
|
||||
bool collective_inited = CollectiveInitializer::instance().collective_inited();
|
||||
if (collective_inited && collective_handle_ != nullptr) {
|
||||
auto get_local_rank_funcptr =
|
||||
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
|
||||
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
|
||||
device_context_key_.device_id_ = IntToUint((*get_local_rank_funcptr)());
|
||||
}
|
||||
|
||||
// Set device id and initialize device resource.
|
||||
bool ret = InitDevice();
|
||||
if (!ret) {
|
||||
|
@ -50,8 +65,6 @@ bool GPUDeviceContext::Initialize() {
|
|||
mem_manager_->MallocDeviceMemory();
|
||||
|
||||
// Initialize NCCL.
|
||||
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
|
||||
bool collective_inited = CollectiveInitializer::instance().collective_inited();
|
||||
if (collective_inited && collective_handle_ != nullptr) {
|
||||
auto init_nccl_comm_funcptr =
|
||||
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
|
||||
|
@ -152,6 +165,97 @@ bool GPUDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddress
|
|||
return true;
|
||||
}
|
||||
|
||||
void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// Operator fusion optimization.
|
||||
FuseOperators(graph);
|
||||
|
||||
device::gpu::AssignGpuStream(graph);
|
||||
|
||||
// Update Graph Dynamic Shape Attr.
|
||||
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
const bool pynative_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
|
||||
// Hide NopOp from execution graph in graph mode
|
||||
if (!pynative_mode) {
|
||||
opt::HideNopNode(graph.get());
|
||||
}
|
||||
}
|
||||
|
||||
void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {
|
||||
// Graph optimization relevant to device data format
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::PostBatchNormAddReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
|
||||
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
|
||||
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
|
||||
pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
|
||||
pm->AddPass(std::make_shared<opt::ReluV2Pass>());
|
||||
pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
|
||||
pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>());
|
||||
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
|
||||
pm->AddPass(std::make_shared<opt::GetitemTuple>());
|
||||
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(graph);
|
||||
graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AdamFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
|
||||
}
|
||||
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
|
||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(graph);
|
||||
graph->SetExecOrderByDefault();
|
||||
|
||||
// Graph kernel fusion optimization
|
||||
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
return;
|
||||
}
|
||||
opt::GraphKernelOptimize(graph);
|
||||
graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void GPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const {
|
||||
for (const auto &cnode : graph->execution_order()) {
|
||||
if (AnfAlgo::IsNodeDynamicShape(cnode)) {
|
||||
AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
|
||||
MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
graph->UpdateGraphDynamicAttr();
|
||||
}
|
||||
|
||||
void GPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(graph);
|
||||
graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void GPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
|
||||
for (const auto &node : nodes) {
|
||||
SetKernelInfo(node);
|
||||
|
|
|
@ -43,6 +43,14 @@ class GPUDeviceContext : public DeviceContext {
|
|||
bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size,
|
||||
const std::vector<size_t> &size_list) const override;
|
||||
|
||||
// General graph optimezer ignore device data type and format.
|
||||
void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override;
|
||||
// Optimize the kernel graph according to device type, such format transform.
|
||||
void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const override;
|
||||
|
||||
// Optimize the single operator graph for PyNative mode.
|
||||
void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override;
|
||||
|
||||
void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;
|
||||
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
|
||||
bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
|
||||
|
@ -54,6 +62,12 @@ class GPUDeviceContext : public DeviceContext {
|
|||
DISABLE_COPY_AND_ASSIGN(GPUDeviceContext);
|
||||
bool InitDevice();
|
||||
|
||||
// Operator fusion optimization.
|
||||
void FuseOperators(const KernelGraphPtr &graph) const;
|
||||
|
||||
// Update Graph Dynamic Shape Attr.
|
||||
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const;
|
||||
|
||||
std::shared_ptr<MemoryManager> mem_manager_;
|
||||
std::vector<void *> streams_;
|
||||
bool initialized_;
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_
|
||||
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/common/pass_manager.h"
|
||||
#include "backend/optimizer/common/common_backend_optimization.h"
|
||||
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
|
||||
#include "backend/optimizer/gpu/adam_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
|
||||
#include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
|
||||
#include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
|
||||
#include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
|
||||
#include "backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h"
|
||||
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
|
||||
#include "backend/optimizer/gpu/combine_momentum_fusion.h"
|
||||
#include "backend/optimizer/gpu/combine_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
|
||||
#include "backend/optimizer/gpu/insert_format_transform_op.h"
|
||||
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_addn_fusion.h"
|
||||
#include "backend/optimizer/gpu/print_reduce_fusion.h"
|
||||
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
|
||||
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
|
||||
#include "backend/optimizer/gpu/reduce_precision_fusion.h"
|
||||
#include "backend/optimizer/gpu/relu_v2_pass.h"
|
||||
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
|
||||
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
|
||||
#include "backend/optimizer/pass/communication_op_fusion.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_
|
Loading…
Reference in New Issue