From 140174182d190d2e6c2a12a1416d6495403b0713 Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Wed, 15 Jul 2020 16:28:45 +0800
Subject: [PATCH] gpu add fusion: replace momentum cast

---
 .../gpu/cuda_impl/momentum_impl.cu            | 27 +++++---
 .../gpu/cuda_impl/momentum_impl.cuh           |  4 +-
 .../kernel_compiler/gpu/gpu_kernel_factory.h  |  6 ++
 .../gpu/nn/fused_batch_norm_gpu_kernel.cc     | 32 +++++-----
 .../gpu/nn/fused_batch_norm_gpu_kernel.h      | 12 ++--
 .../gpu/nn/fused_batchnorm_grad_gpu_kernel.cc | 10 +--
 .../gpu/nn/fused_batchnorm_grad_gpu_kernel.h  | 10 +--
 .../gpu/nn/momentum_gpu_kernel.cc             | 63 +++++++++++--------
 .../gpu/nn/momentum_gpu_kernel.h              |  6 +-
 .../gpu/replace_momentum_cast_fusion.cc       | 63 +++++++++++++++++++
 .../gpu/replace_momentum_cast_fusion.h        | 46 ++++++++++++++
 .../ccsrc/backend/session/gpu_session.cc      | 10 +++
 12 files changed, 218 insertions(+), 71 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
index 5a1c9eb6871..03a4ccb6178 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
@@ -15,9 +15,9 @@
  */

 #include "momentum_impl.cuh"
-template <typename T, typename S>
+template <typename T, typename S, typename G>
 __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate,
-                                             const T *gradient, const S *momentum) {
+                                             const G *gradient, const S *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
     accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
     variable[i] -= learning_rate[0] * accumulation[i];
@@ -34,19 +34,32 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable,
   }
   return;
 }
-template <typename T, typename S>
-void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient,
+template <>
+__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
+                                             const float *learning_rate, const half *gradient,
+                                             const float *momentum) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
+    accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
+    variable[i] -= learning_rate[0] * accumulation[i];
+  }
+  return;
+}
+template <typename T, typename S, typename G>
+void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
                             const S *momentum, cudaStream_t cuda_stream) {
   MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, accumulation,
                                                                                   learning_rate, gradient, momentum);
   return;
 }
-template void MomentumUpdateVariable<float, float>(const size_t size, float *variable, float *accumulation,
+template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
                                                    const float *learning_rate, const float *gradient,
                                                    const float *momentum, cudaStream_t cuda_stream);
-template void MomentumUpdateVariable<half, half>(const size_t size, half *variable, half *accumulation,
+template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
                                                  const half *learning_rate, const half *gradient,
                                                  const half *momentum, cudaStream_t cuda_stream);
-template void MomentumUpdateVariable<half, float>(const size_t size, half *variable, half *accumulation,
-                                                  const float *learning_rate, const half *gradient,
-                                                  const float *momentum, cudaStream_t cuda_stream);
+template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
+                                                        const float *learning_rate, const half *gradient,
+                                                        const float *momentum, cudaStream_t cuda_stream);
+template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
+                                                         const float *learning_rate, const half *gradient,
+                                                         const float *momentum, cudaStream_t cuda_stream);
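
Context, not part of the patch itself: the new <float, float, half> instantiation lets ApplyMomentum consume a float16 gradient directly, widening each element inside the kernel instead of going through a separate Cast node in the graph. A rough host-side sketch of the per-element arithmetic of the added specialization follows; the function name MomentumUpdateReference and the value-passed scalars are illustrative only, not from the MindSpore tree.

// Illustrative CPU reference of what MomentumUpdateVariableKernel<float, float, half>
// computes per element: the half gradient is widened with __half2float, which is the
// same conversion the removed Cast node performed.
#include <cuda_fp16.h>
#include <cstddef>

void MomentumUpdateReference(std::size_t size, float *variable, float *accumulation,
                             float learning_rate, const half *gradient, float momentum) {
  for (std::size_t i = 0; i < size; ++i) {
    accumulation[i] = momentum * accumulation[i] + __half2float(gradient[i]);
    variable[i] -= learning_rate * accumulation[i];
  }
}
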
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
index 62708663ad3..e5a22e47913 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
@@ -18,8 +18,8 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_

 #include "runtime/device/gpu/cuda_common.h"
-template <typename T, typename S>
-void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient,
+template <typename T, typename S, typename G>
+void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
                             const S *momentum, cudaStream_t cuda_stream);

 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
index 8834fa0f1a6..7af82442030 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
@@ -88,6 +88,12 @@ class GpuKernelRegister {
   static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel");       \
   static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR,                \
                                                                          []() { return new OPCLASS<T, S>(); });
+
+// Register mixed-precision kernels whose kernel class template takes three type parameters (T, S, G).
+#define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G)                                        \
+  static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel");    \
+  static const GpuKernelRegister g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg(                        \
+    #OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc
index 2ce39b63a02..ddd9c2f8d05 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc
@@ -34,15 +34,15 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
 MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
                       FusedBatchNormGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(BatchNorm,
                       KernelAttr()
@@ -60,15 +60,15 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
 MS_REG_GPU_KERNEL_ONE(BatchNorm,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
                       FusedBatchNormGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h
index 774428dc409..fcf21b02142 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h
@@ -56,17 +56,17 @@ class FusedBatchNormGpuKernel : public GpuKernel {
       return true;
     }
     auto x = GetDeviceAddress<T>(inputs, 0);
-    auto scale = GetDeviceAddress<T>(inputs, 1);
-    auto bias = GetDeviceAddress<T>(inputs, 2);
-    auto runing_mean = GetDeviceAddress<T>(inputs, 3);
-    auto runnig_variance = GetDeviceAddress<T>(inputs, 4);
+    auto scale = GetDeviceAddress<float>(inputs, 1);
+    auto bias = GetDeviceAddress<float>(inputs, 2);
+    auto runing_mean = GetDeviceAddress<float>(inputs, 3);
+    auto runnig_variance = GetDeviceAddress<float>(inputs, 4);
     auto y = GetDeviceAddress<T>(outputs, 0);

     const float alpha = 1;
     const float beta = 0;
     if (is_train_) {
-      auto save_mean = GetDeviceAddress<T>(outputs, 3);
-      auto save_variance = GetDeviceAddress<T>(outputs, 4);
+      auto save_mean = GetDeviceAddress<float>(outputs, 3);
+      auto save_variance = GetDeviceAddress<float>(outputs, 4);
       CHECK_CUDNN_RET_WITH_EXCEPT(
         cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y,
                                                scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc
index 546e034f6bf..7cd993d0f0d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc
@@ -33,12 +33,12 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
                       FusedBatchNormGradGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h
index a2d0d741b13..e2da67b85ef 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h
@@ -55,12 +55,12 @@ class FusedBatchNormGradGpuKernel : public GpuKernel {
     }
     auto dy = GetDeviceAddress<T>(inputs, 0);
     auto x = GetDeviceAddress<T>(inputs, 1);
-    auto scale = GetDeviceAddress<T>(inputs, 2);
-    auto save_mean = GetDeviceAddress<T>(inputs, 3);
-    auto save_variance = GetDeviceAddress<T>(inputs, 4);
+    auto scale = GetDeviceAddress<float>(inputs, 2);
+    auto save_mean = GetDeviceAddress<float>(inputs, 3);
+    auto save_variance = GetDeviceAddress<float>(inputs, 4);
     auto dx = GetDeviceAddress<T>(outputs, 0);
-    auto bn_scale = GetDeviceAddress<T>(outputs, 1);
-    auto bn_bias = GetDeviceAddress<T>(outputs, 2);
+    auto bn_scale = GetDeviceAddress<float>(outputs, 1);
+    auto bn_bias = GetDeviceAddress<float>(outputs, 2);

     const float alpha_data_diff = 1;
     const float beta_data_diff = 0;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc
index 99ae2affe8c..96411e9bbf9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc
@@ -18,32 +18,41 @@

 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      MomentumGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      MomentumGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      MomentumGpuKernel, half, float)
+MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        MomentumGpuKernel, float, float, float)
+MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddOutputAttr(kNumberTypeFloat16),
+                        MomentumGpuKernel, half, half, half)
+MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat16),
+                        MomentumGpuKernel, half, float, half)
+MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        MomentumGpuKernel, float, float, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h
index 32d3fbb079b..02b897f679e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h
@@ -23,7 +23,7 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
 namespace mindspore {
 namespace kernel {
-template <typename T, typename S>
+template <typename T, typename S, typename G>
 class MomentumGpuKernel : public GpuKernel {
  public:
   MomentumGpuKernel()
@@ -38,7 +38,7 @@ class MomentumGpuKernel : public GpuKernel {
     T *variable = GetDeviceAddress<T>(inputs, 0);
     T *accumulation = GetDeviceAddress<T>(inputs, 1);
     S *learning_rate = GetDeviceAddress<S>(inputs, 2);
-    T *gradient = GetDeviceAddress<T>(inputs, 3);
+    G *gradient = GetDeviceAddress<G>(inputs, 3);
     S *momentum = GetDeviceAddress<S>(inputs, 4);
     MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
@@ -54,7 +54,7 @@ class MomentumGpuKernel : public GpuKernel {
     variable_size_ = sizeof(T);
     accumulation_size_ = sizeof(T);
     learning_rate_size_ = sizeof(S);
-    gradient_size_ = sizeof(T);
+    gradient_size_ = sizeof(G);
     momentum_size_ = sizeof(S);

     auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc
new file mode 100644
index 00000000000..864bb026af3
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef ReplaceMomentumCastFusion::DefinePattern() const {
+  VectorRef grad_cast = VectorRef({prim::kPrimCast, grad_});
+  VectorRef momentum = VectorRef({prim::kPrimApplyMomentum, var_, acc_, lr_, grad_cast, mom_});
+  return momentum;
+}
+
+const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                    const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+
+  auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
+  auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
+  MS_EXCEPTION_IF_NULL(grad_cast);
+  MS_EXCEPTION_IF_NULL(grad);
+
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad));
+  std::vector<TypeId> outputs_type;
+  std::vector<std::vector<size_t>> outputs_shape;
+  auto output_num = AnfAlgo::GetOutputTensorNum(node);
+  for (size_t i = 0; i < output_num; i++) {
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i));
+    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i));
+  }
+  outputs_type[3] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
+
+  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get());
+
+  return node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h
new file mode 100644
index 00000000000..f67033dcbee
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ReplaceMomentumCastFusion : public PatternProcessPass {
+ public:
+  explicit ReplaceMomentumCastFusion(bool multigraph = true) : PatternProcessPass("replace_momentum_cast", multigraph) {
+    var_ = std::make_shared<Var>();
+    acc_ = std::make_shared<Var>();
+    lr_ = std::make_shared<Var>();
+    grad_ = std::make_shared<Var>();
+    mom_ = std::make_shared<Var>();
+  }
+  ~ReplaceMomentumCastFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr var_;
+  VarPtr acc_;
+  VarPtr lr_;
+  VarPtr grad_;
+  VarPtr mom_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index 14e30c1a443..e289cffa56b 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -25,6 +25,11 @@
 #include "backend/optimizer/pass/getitem_tuple.h"
 #include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
 #include "backend/optimizer/gpu/adam_fusion.h"
+#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
+#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
+#include "backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h"
+#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
+#include "backend/optimizer/gpu/replace_addn_fusion.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "predict/predict.h"
 #include "common/utils.h"
@@ -59,6 +64,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   auto pm = std::make_shared<opt::PassManager>();
   pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
   pm->AddPass(std::make_shared<opt::AdamFusion>());
+  pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>());
+  pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
+  pm->AddPass(std::make_shared<opt::ReplaceBNGradCast2Fusion>());
+  pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
+  pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
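
A closing note, not part of the patch: ReplaceMomentumCastFusion only drops the Cast that feeds ApplyMomentum's gradient input, and it is numerically neutral because the fused <float, float, half> kernel widens the float16 gradient with __half2float, which is exactly what the removed Cast did. A small self-contained sketch of that equivalence is below; it is illustrative only and not part of the MindSpore tree.

// Illustrative check: "Cast to float32, then update" and "update directly from the
// float16 gradient" produce identical results, which is why the graph pass can
// safely remove the Cast node in front of ApplyMomentum.
#include <cuda_fp16.h>
#include <cassert>
#include <cstdio>

int main() {
  const half grad = __float2half(0.125f);  // gradient from a float16 backward pass
  const float lr = 0.1f, momentum = 0.9f;
  const float acc = 0.5f, var = 1.0f;

  // Path 1: explicit Cast node in the graph, then a float32-only momentum update.
  const float grad_cast = __half2float(grad);
  const float acc1 = momentum * acc + grad_cast;
  const float var1 = var - lr * acc1;

  // Path 2: fused update that widens the float16 gradient inside the kernel.
  const float acc2 = momentum * acc + __half2float(grad);
  const float var2 = var - lr * acc2;

  assert(acc1 == acc2 && var1 == var2);
  std::printf("variable = %f, accumulation = %f\n", var1, acc1);
  return 0;
}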