add graph compiler

lizhenyu 2021-03-24 11:46:26 +08:00
parent d2ee376b7b
commit c57a9edb9d
11 changed files with 354 additions and 22 deletions

View File

@@ -128,6 +128,13 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
// Get graph by graph id; if it does not exist, return null ptr
KernelGraphPtr GetGraph(GraphId graph_id) const;
void ClearGraph();
+ // create a single run op graph
+ std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
+ const std::vector<tensor::TensorPtr> &input_tensors,
+ const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
+ void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
+ void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
+ void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
#ifdef ENABLE_DEBUGGER
// set debugger
void SetDebugger() {
@@ -163,12 +170,12 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
- virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0;
- virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
+ virtual void UnifyMindIR(const KernelGraphPtr &graph) {}
+ virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
virtual void BuildGraphImpl(GraphId) {}
- virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
- VectorRef *outputs) = 0;
+ virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {}
virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {}
@@ -183,7 +190,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const;
- void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
const std::vector<tensor::TensorPtr> &input_tensors) const;
void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const;
@@ -191,10 +197,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
// create graph output for RunOp
void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
- // create a single run op graph
- std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
- const std::vector<tensor::TensorPtr> &input_tensors,
- const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
// Generate graph info for a single op graph
GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors);
void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info);
@@ -219,8 +221,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
- void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
- void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
void InitAllBucket(const KernelGraphPtr &graph);
void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
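With UnifyMindIR, CompileGraphImpl and RunGraphImpl demoted from pure virtual to no-op defaults, a backend session now only overrides the hooks it actually implements. A minimal sketch, assuming a hypothetical MyBackendSession (not part of this commit; ConstructKernelGraph is the inherited helper that GraphCompiler also calls):

class MyBackendSession : public session::SessionBasic {
 protected:
  // Only the compile hook is overridden; UnifyMindIR and RunGraphImpl
  // fall back to the default bodies introduced by this commit.
  GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override {
    auto graph = ConstructKernelGraph(lst, outputs);
    MS_EXCEPTION_IF_NULL(graph);
    return graph->graph_id();
  }
};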

View File

@@ -0,0 +1,110 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/graph_compiler.h"
#include "runtime/framework/graph_scheduler.h"
namespace mindspore {
namespace runtime {
void GraphCompiler::set_device_context(device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
device_context_ = device_context;
// The member variable 'session_' will be removed after the session module is removed.
if (session_ == nullptr) {
session_ = std::make_shared<session::SessionBasic>();
}
}
GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) {
MS_EXCEPTION_IF_NULL(session_);
// Generate kernel graph.
auto graph = session_->ConstructKernelGraph(nodes, outputs);
MS_EXCEPTION_IF_NULL(graph);
return CompileGraphImpl(graph);
}
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context_);
// Optimization passes that do not depend on device type or format.
device_context_->OptimizeGraphWithoutDeviceInfo(graph);
device_context_->SetOperatorInfo(graph->execution_order());
// Optimization passes that depend on device type or format.
device_context_->OptimizeGraphWithDeviceInfo(graph);
// Generate a 'KernelMod' for each kernel and attach it to the kernel;
// 'KernelMod' is the real execution object of the kernel.
device_context_->CreateKernel(graph->execution_order());
// Transform the graph to an actor DAG; this covers both build and link.
GraphScheduler::GetInstance().Transform(graph, device_context_);
return graph->graph_id();
}
void GraphCompiler::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(session_);
auto graph = session_->GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
MS_EXCEPTION_IF_NULL(actor_set);
GraphScheduler::GetInstance().Run(actor_set);
}
void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
std::vector<tensor::TensorPtr> *input_tensors,
const std::vector<int64_t> &tensors_mask, VectorRef *outputs) {
// Check if the graph cache exists.
if (run_op_graphs_.find(graph_info) == run_op_graphs_.end()) {
// Prepare the graph
MS_EXCEPTION_IF_NULL(session_);
auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context_);
device_context_->SetOperatorInfo(graph->execution_order());
device_context_->OptimizeSingleOpGraph(graph);
MS_EXCEPTION_IF_NULL(session_);
session_->RunOpHideNopNode(graph);
device_context_->CreateKernel(graph->execution_order());
run_op_graphs_[graph_info] = graph;
}
session_->EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
for (auto &tensor : *input_tensors) {
if (tensor->NeedWaitDevice()) {
tensor->WaitDevice();
}
}
// run op
auto graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(graph);
session_->RunOpRemoveNopNode(graph);
GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
MS_EXCEPTION_IF_NULL(actor_set);
GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep);
}
} // namespace runtime
} // namespace mindspore
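Taken together, the intended graph-mode call sequence is: bind a device context, compile once, then run by graph id. A minimal sketch (whether GraphCompiler is constructed directly or held as a singleton is outside this hunk; the local names are illustrative):

// Hedged sketch of the flow added by this commit.
runtime::GraphCompiler compiler;
compiler.set_device_context(device_context);  // also creates the transitional SessionBasic
// Compile: kernel graph -> optimization passes -> kernel selection -> actor DAG.
GraphId graph_id = compiler.CompileGraph(nodes, outputs);
// Run: fetch the actor set produced by Transform() and dispatch it.
VectorRef graph_outputs;
compiler.RunGraph(graph_id, input_tensors, &graph_outputs);

For PyNative mode, CompileAndRunGraph plays both roles at single-op granularity, caching each compiled op graph in run_op_graphs_ keyed by graph_info.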

View File

@@ -45,7 +45,7 @@ class GraphCompiler {
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
// Construct a single-op kernel graph, then compile and run it in PyNative mode.
- void CompileAndRunGraph(OpRunInfo *op_run_info, const GraphInfo &graph_info,
+ void CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask,
VectorRef *outputs);
@@ -61,7 +61,7 @@ class GraphCompiler {
device::DeviceContext *device_context_{nullptr};
// Single op kernel graph cache for PyNative mode.
- std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
+ std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_;
// The member variable 'session_' will be removed after the session module is removed.
session::SessionPtr session_{nullptr};

View File

@@ -21,6 +21,11 @@
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/trace_base.h"
+ #include "backend/optimizer/common/optimizer.h"
+ #include "backend/optimizer/common/pass_manager.h"
+ #include "backend/optimizer/cpu/insert_cast_cpu.h"
+ #include "backend/optimizer/pass/replace_node_by_proxy.h"
+ #include "backend/optimizer/pass/erase_visit_attr.h"
namespace mindspore {
namespace device {
@@ -45,6 +50,40 @@ void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const {
address->ptr_ = nullptr;
}
void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
// Update Graph Dynamic Shape Attr.
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
OptimizeGraphImpl(graph);
// Remove this reorder after the PS feature finishes adapting push/pull in auto_monad.
auto execution_order = graph->execution_order();
AnfAlgo::ReorderPosteriorExecList(NOT_NULL(&execution_order));
graph->set_execution_order(execution_order);
}
void CPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { OptimizeGraphImpl(graph); }
void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::InsertCastCPU>());
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);
graph->SetExecOrderByDefault();
}
void CPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const {
for (const auto &cnode : graph->execution_order()) {
if (AnfAlgo::IsNodeDynamicShape(cnode)) {
AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
}
}
graph->UpdateGraphDynamicAttr();
}
void CPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
for (const auto &node : nodes) {
SetKernelInfo(node);

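The GraphOptimizer/PassManager pattern used in OptimizeGraphImpl recurs in every backend touched by this commit. A minimal sketch of extending the CPU pipeline, assuming a hypothetical opt::MyCpuPass (illustrative only):

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::InsertCastCPU>());  // passes run in registration order
pm->AddPass(std::make_shared<opt::MyCpuPass>());      // hypothetical extra pass
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);
// Pass rewrites can invalidate the cached node order, so rebuild it.
graph->SetExecOrderByDefault();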
View File

@@ -36,15 +36,23 @@ class CPUDeviceContext : public DeviceContext {
bool AllocateMemory(DeviceAddress *const &address, size_t size) const override;
void FreeMemory(DeviceAddress *const &address) const override;
+ void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override;
+ void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override;
void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override;
private:
DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);
+ // Update Graph Dynamic Shape Attr.
+ void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const;
+ void OptimizeGraphImpl(const KernelGraphPtr &graph) const;
uint32_t device_id_;
std::shared_ptr<MemoryManager> mem_manager_;
bool initialized_;
};

View File

@@ -63,17 +63,23 @@ class DeviceContext {
return true;
}
// Optimize the kernel graph according to different devices.
virtual void OptimizeGraph(const KernelGraphPtr &graph) const {}
+ // The two functions below will be merged into one in the future.
+ // General graph optimization; ignores device data type and format.
+ virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {}
+ // Optimize the kernel graph according to device data type and format.
+ virtual void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {}
+ // Optimize the single operator graph for PyNative mode.
+ virtual void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {}
// For every execution operator, select a matching backend kernel according to the data types and
// formats of its inputs and outputs, and set the final device data type and format on the kernel;
// the device data type and format replace the original ones when the kernel is executed.
- virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {}
+ virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const = 0;
// Generate a 'KernelMod' for each kernel and attach it to the kernel;
// 'KernelMod' is the real execution object of the kernel.
- virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}
+ virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const = 0;
// Launch a kernel via 'KernelMod' of the kernel.
virtual bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,

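Because SetOperatorInfo and CreateKernel are now pure virtual, every concrete device context must implement them. A minimal sketch of the required shape (DemoDeviceContext is hypothetical, and only the two newly pure hooks are shown; a real backend, like the CPU and GPU contexts in this commit, also overrides the optimization and launch methods):

class DemoDeviceContext : public DeviceContext {
 public:
  using DeviceContext::DeviceContext;  // assumes the base constructor is inheritable
  void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override {
    for (const auto &node : nodes) {
      // Select a kernel matching the node's input/output data types and
      // formats, and record the chosen device type/format on the node.
    }
  }
  void CreateKernel(const std::vector<CNodePtr> &nodes) const override {
    for (const auto &node : nodes) {
      // Build a KernelMod for the node and attach it as the execution object.
    }
  }
};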
View File

@@ -34,7 +34,7 @@ void DeviceContextManager::ClearDeviceContexts() {
device_contexts_.clear();
}
- DeviceContext *DeviceContextManager::GetDeviceContext(const DeviceContextKey &device_context_key) {
+ DeviceContext *DeviceContextManager::CreateOrGetDeviceContext(const DeviceContextKey &device_context_key) {
std::string device_context_key_str = device_context_key.ToString();
std::lock_guard<std::mutex> guard(lock_);

View File

@@ -36,7 +36,7 @@ class DeviceContextManager {
return instance;
}
void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
- DeviceContext *GetDeviceContext(const DeviceContextKey &device_info);
+ DeviceContext *CreateOrGetDeviceContext(const DeviceContextKey &device_context_key);
void ClearDeviceContexts();
void ClearDeviceContexts();
private:

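The rename makes the create-on-miss behavior explicit at the call site. A hedged usage sketch, assuming the brace-initialized field layout of DeviceContextKey (inferred from device_context_key_.device_id_ in the GPU context below) and that Initialize() is declared on the common interface:

// Assumed key shape: device name plus device id.
DeviceContextKey key{"GPU", 0};
auto *device_context = DeviceContextManager::GetInstance().CreateOrGetDeviceContext(key);
MS_EXCEPTION_IF_NULL(device_context);
device_context->Initialize();  // e.g. the GPU Initialize() shown below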
View File

@@ -27,16 +27,31 @@
#include "runtime/device/gpu/gpu_buffer_mgr.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/gpu/gpu_common.h"
+ #include "runtime/hardware/gpu/optimizer.h"
#include "common/trans.h"
#include "utils/context/graph_kernel_flags.h"
namespace mindspore {
namespace device {
namespace gpu {
bool GPUDeviceContext::Initialize() {
+ if (initialized_) {
+ CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_context_key_.device_id_)),
+ "Failed to set device id");
+ GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
+ return true;
+ }
+ // Set device id
+ const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
+ bool collective_inited = CollectiveInitializer::instance().collective_inited();
+ if (collective_inited && collective_handle_ != nullptr) {
+ auto get_local_rank_funcptr =
+ reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
+ MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
+ device_context_key_.device_id_ = IntToUint((*get_local_rank_funcptr)());
+ }
// Set device id and initialize device resource.
bool ret = InitDevice();
if (!ret) {
// Set device id and initialize device resource.
bool ret = InitDevice();
if (!ret) {
@@ -50,8 +65,6 @@ bool GPUDeviceContext::Initialize() {
mem_manager_->MallocDeviceMemory();
// Initialize NCCL.
- const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
- bool collective_inited = CollectiveInitializer::instance().collective_inited();
if (collective_inited && collective_handle_ != nullptr) {
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
@@ -152,6 +165,97 @@ bool GPUDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddress
return true;
}
void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
// Operator fusion optimization.
FuseOperators(graph);
device::gpu::AssignGpuStream(graph);
// Update Graph Dynamic Shape Attr.
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
const bool pynative_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
// Hide NopOp from execution graph in graph mode
if (!pynative_mode) {
opt::HideNopNode(graph.get());
}
}
void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {
// Graph optimization relevant to device data format
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
pm->AddPass(std::make_shared<opt::PostBatchNormAddReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
pm->AddPass(std::make_shared<opt::ReluV2Pass>());
pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>());
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
pm->AddPass(std::make_shared<opt::GetitemTuple>());
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);
graph->SetExecOrderByDefault();
}
void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
pm->AddPass(std::make_shared<opt::AdamFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
}
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);
graph->SetExecOrderByDefault();
// Graph kernel fusion optimization
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
return;
}
opt::GraphKernelOptimize(graph);
graph->SetExecOrderByDefault();
}
void GPUDeviceContext::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const {
for (const auto &cnode : graph->execution_order()) {
if (AnfAlgo::IsNodeDynamicShape(cnode)) {
AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
}
}
graph->UpdateGraphDynamicAttr();
}
void GPUDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);
graph->SetExecOrderByDefault();
}
void GPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
for (const auto &node : nodes) {
SetKernelInfo(node);

View File

@@ -43,6 +43,14 @@ class GPUDeviceContext : public DeviceContext {
bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size,
const std::vector<size_t> &size_list) const override;
+ // General graph optimization; ignores device data type and format.
+ void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override;
+ // Optimize the kernel graph according to device type, such as format transformation.
+ void OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const override;
+ // Optimize the single operator graph for PyNative mode.
+ void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override;
void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;
void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
@@ -54,6 +62,12 @@ class GPUDeviceContext : public DeviceContext {
DISABLE_COPY_AND_ASSIGN(GPUDeviceContext);
bool InitDevice();
+ // Operator fusion optimization.
+ void FuseOperators(const KernelGraphPtr &graph) const;
+ // Update Graph Dynamic Shape Attr.
+ void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &graph) const;
std::shared_ptr<MemoryManager> mem_manager_;
std::vector<void *> streams_;
bool initialized_;

View File

@@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/common/common_backend_optimization.h"
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
#include "backend/optimizer/gpu/adam_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
#include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
#include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
#include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
#include "backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h"
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
#include "backend/optimizer/gpu/combine_momentum_fusion.h"
#include "backend/optimizer/gpu/combine_cast_fusion.h"
#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include "backend/optimizer/gpu/replace_addn_fusion.h"
#include "backend/optimizer/gpu/print_reduce_fusion.h"
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
#include "backend/optimizer/gpu/reduce_precision_fusion.h"
#include "backend/optimizer/gpu/relu_v2_pass.h"
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/optimizer/pass/communication_op_fusion.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_