forked from mindspore-Ecosystem/mindspore
!3092 GPU add fusion: replace momentum cast
Merge pull request !3092 from VectorSL/momentum
This commit is contained in:
commit
ae50c37c38
|
@ -15,9 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "momentum_impl.cuh"
|
||||
template <typename T, typename S>
|
||||
template <typename T, typename S, typename G>
|
||||
__global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate,
|
||||
const T *gradient, const S *momentum) {
|
||||
const G *gradient, const S *momentum) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
|
||||
accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
|
||||
variable[i] -= learning_rate[0] * accumulation[i];
|
||||
|
@ -34,19 +34,32 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable,
|
|||
}
|
||||
return;
|
||||
}
|
||||
template <typename T, typename S>
|
||||
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient,
|
||||
// Mixed-precision specialization of the momentum update: float variable/accumulation/learning-rate/momentum
// combined with a half-precision gradient.
//
// For every element i in [0, size):
//   accumulation[i] = momentum[0] * accumulation[i] + float(gradient[i])
//   variable[i]    -= learning_rate[0] * accumulation[i]
//
// learning_rate and momentum are read only at index 0. The loop strides by the total number of launched
// threads, so any 1-D launch configuration covers all `size` elements.
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
                                             const float *learning_rate, const half *gradient,
                                             const float *momentum) {
  // blockIdx/blockDim/gridDim components are 32-bit unsigned; widen before multiplying so that neither the
  // starting index nor the stride can wrap around for very large launches.
  const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = first; i < size; i += stride) {
    // The gradient is converted to float so the accumulation is carried out in single precision.
    accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
    variable[i] -= learning_rate[0] * accumulation[i];
  }
}
|
||||
// Host-side launcher for MomentumUpdateVariableKernel.
//
// T is the element type of variable/accumulation, S the type of the learning_rate/momentum buffers and
// G the element type of the gradient. The kernel is enqueued on `cuda_stream`; this function does not
// synchronize with the stream itself.
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
                            const S *momentum, cudaStream_t cuda_stream) {
  // Launch configuration comes from the project's GET_BLOCKS / GET_THREADS helpers; no dynamic shared memory.
  const auto grid_dim = GET_BLOCKS(size);
  const auto block_dim = GET_THREADS;
  MomentumUpdateVariableKernel<<<grid_dim, block_dim, 0, cuda_stream>>>(size, variable, accumulation, learning_rate,
                                                                        gradient, momentum);
}
|
||||
template void MomentumUpdateVariable<float, float>(const size_t size, float *variable, float *accumulation,
|
||||
template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
|
||||
const float *learning_rate, const float *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
template void MomentumUpdateVariable<half, half>(const size_t size, half *variable, half *accumulation,
|
||||
template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
|
||||
const half *learning_rate, const half *gradient,
|
||||
const half *momentum, cudaStream_t cuda_stream);
|
||||
template void MomentumUpdateVariable<half, float>(const size_t size, half *variable, half *accumulation,
|
||||
template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
|
||||
const float *learning_rate, const half *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
|
||||
const float *learning_rate, const half *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T, typename S>
|
||||
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient,
|
||||
template <typename T, typename S, typename G>
|
||||
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
|
||||
const S *momentum, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
|
||||
|
|
|
@ -88,6 +88,12 @@ class GpuKernelRegister {
|
|||
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \
|
||||
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \
|
||||
[]() { return new OPCLASS<T, S>(); });
|
||||
|
||||
// Registers mixed-precision kernels whose kernel class template takes three type parameters.
|
||||
#define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \
|
||||
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel"); \
|
||||
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg( \
|
||||
#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_
|
||||
|
|
|
@ -34,15 +34,15 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
|
|||
MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNorm,
|
||||
KernelAttr()
|
||||
|
@ -60,15 +60,15 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
|
|||
MS_REG_GPU_KERNEL_ONE(BatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,17 +56,17 @@ class FusedBatchNormGpuKernel : public GpuKernel {
|
|||
return true;
|
||||
}
|
||||
auto x = GetDeviceAddress<T>(inputs, 0);
|
||||
auto scale = GetDeviceAddress<T>(inputs, 1);
|
||||
auto bias = GetDeviceAddress<T>(inputs, 2);
|
||||
auto runing_mean = GetDeviceAddress<T>(inputs, 3);
|
||||
auto runnig_variance = GetDeviceAddress<T>(inputs, 4);
|
||||
auto scale = GetDeviceAddress<float>(inputs, 1);
|
||||
auto bias = GetDeviceAddress<float>(inputs, 2);
|
||||
auto runing_mean = GetDeviceAddress<float>(inputs, 3);
|
||||
auto runnig_variance = GetDeviceAddress<float>(inputs, 4);
|
||||
auto y = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
if (is_train_) {
|
||||
auto save_mean = GetDeviceAddress<T>(outputs, 3);
|
||||
auto save_variance = GetDeviceAddress<T>(outputs, 4);
|
||||
auto save_mean = GetDeviceAddress<float>(outputs, 3);
|
||||
auto save_variance = GetDeviceAddress<float>(outputs, 4);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y,
|
||||
scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean,
|
||||
|
|
|
@ -33,12 +33,12 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
|
|||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGradGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,12 +55,12 @@ class FusedBatchNormGradGpuKernel : public GpuKernel {
|
|||
}
|
||||
auto dy = GetDeviceAddress<T>(inputs, 0);
|
||||
auto x = GetDeviceAddress<T>(inputs, 1);
|
||||
auto scale = GetDeviceAddress<T>(inputs, 2);
|
||||
auto save_mean = GetDeviceAddress<T>(inputs, 3);
|
||||
auto save_variance = GetDeviceAddress<T>(inputs, 4);
|
||||
auto scale = GetDeviceAddress<float>(inputs, 2);
|
||||
auto save_mean = GetDeviceAddress<float>(inputs, 3);
|
||||
auto save_variance = GetDeviceAddress<float>(inputs, 4);
|
||||
auto dx = GetDeviceAddress<T>(outputs, 0);
|
||||
auto bn_scale = GetDeviceAddress<T>(outputs, 1);
|
||||
auto bn_bias = GetDeviceAddress<T>(outputs, 2);
|
||||
auto bn_scale = GetDeviceAddress<float>(outputs, 1);
|
||||
auto bn_bias = GetDeviceAddress<float>(outputs, 2);
|
||||
|
||||
const float alpha_data_diff = 1;
|
||||
const float beta_data_diff = 0;
|
||||
|
|
|
@ -18,32 +18,41 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MomentumGpuKernel, float, float)
|
||||
MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
MomentumGpuKernel, half, half)
|
||||
MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
MomentumGpuKernel, half, float)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MomentumGpuKernel, float, float, float)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
MomentumGpuKernel, half, half, half)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
MomentumGpuKernel, half, float, half)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
MomentumGpuKernel, float, float, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
template <typename T, typename S, typename G>
|
||||
class MomentumGpuKernel : public GpuKernel {
|
||||
public:
|
||||
MomentumGpuKernel()
|
||||
|
@ -38,7 +38,7 @@ class MomentumGpuKernel : public GpuKernel {
|
|||
T *variable = GetDeviceAddress<T>(inputs, 0);
|
||||
T *accumulation = GetDeviceAddress<T>(inputs, 1);
|
||||
S *learning_rate = GetDeviceAddress<S>(inputs, 2);
|
||||
T *gradient = GetDeviceAddress<T>(inputs, 3);
|
||||
G *gradient = GetDeviceAddress<G>(inputs, 3);
|
||||
S *momentum = GetDeviceAddress<S>(inputs, 4);
|
||||
MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
@ -54,7 +54,7 @@ class MomentumGpuKernel : public GpuKernel {
|
|||
variable_size_ = sizeof(T);
|
||||
accumulation_size_ = sizeof(T);
|
||||
learning_rate_size_ = sizeof(S);
|
||||
gradient_size_ = sizeof(T);
|
||||
gradient_size_ = sizeof(G);
|
||||
momentum_size_ = sizeof(S);
|
||||
|
||||
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Builds the pattern this pass matches: an ApplyMomentum node whose fourth operand (the gradient)
// is produced by a Cast node, i.e. ApplyMomentum(var, acc, lr, Cast(grad), mom).
const BaseRef ReplaceMomentumCastFusion::DefinePattern() const {
  // Inner sub-pattern: Cast applied to the gradient placeholder.
  VectorRef cast_of_grad({prim::kPrimCast, grad_});
  // Outer pattern: ApplyMomentum consuming the Cast in the gradient position.
  VectorRef apply_momentum({prim::kPrimApplyMomentum, var_, acc_, lr_, cast_of_grad, mom_});
  return apply_momentum;
}
|
||||
|
||||
// Rewrites a node matched by DefinePattern() so that ApplyMomentum consumes the un-cast gradient directly.
//
// graph : function graph that owns `node`; its manager performs the replacement.
// node  : the matched ApplyMomentum node.
// equiv : pattern bindings (only checked for null here).
//
// Returns `node` after the Cast feeding its gradient input has been replaced by the Cast's own input.
// Raises (via MS_EXCEPTION_IF_NULL) when any argument, the gradient Cast, its input, or the graph manager
// is null.
const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                    const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);

  // Input 3 of the matched node is the Cast in the gradient position of the pattern. Validate it before it
  // is used to look up its own input.
  auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
  MS_EXCEPTION_IF_NULL(grad_cast);
  auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
  MS_EXCEPTION_IF_NULL(grad);

  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad));

  // Re-apply the node's inferred output types and shapes. They are left unchanged: bypassing the Cast
  // alters an input of the node, not its outputs. Only `output_num` entries exist, so no fixed index
  // beyond that range may be written.
  const size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
  outputs_type.reserve(output_num);
  outputs_shape.reserve(output_num);
  for (size_t i = 0; i < output_num; i++) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i));
    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i));
  }
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get());

  return node;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ReplaceMomentumCastFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit ReplaceMomentumCastFusion(bool multigraph = true) : PatternProcessPass("replace_momentum_cast", multigraph) {
|
||||
var_ = std::make_shared<Var>();
|
||||
acc_ = std::make_shared<Var>();
|
||||
lr_ = std::make_shared<Var>();
|
||||
grad_ = std::make_shared<Var>();
|
||||
mom_ = std::make_shared<Var>();
|
||||
}
|
||||
~ReplaceMomentumCastFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
VarPtr var_;
|
||||
VarPtr acc_;
|
||||
VarPtr lr_;
|
||||
VarPtr grad_;
|
||||
VarPtr mom_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
|
|
@ -25,6 +25,11 @@
|
|||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
|
||||
#include "backend/optimizer/gpu/adam_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_addn_fusion.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "predict/predict.h"
|
||||
#include "common/utils.h"
|
||||
|
@ -59,6 +64,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AdamFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceBNGradCast2Fusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
Loading…
Reference in New Issue