forked from mindspore-Ecosystem/mindspore
!5612 Resnet50 pattern Fusion
Merge pull request !5612 from chenweifeng/BatchNormAddReluGrad
This commit is contained in:
commit
1944b8e53b
77
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
Executable file → Normal file
77
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
Executable file → Normal file
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,8 +26,7 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *
|
|||
}
|
||||
template <>
|
||||
__global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation,
|
||||
const float *learning_rate, const half *gradient,
|
||||
const float *momentum) {
|
||||
const float *learning_rate, const half *gradient, const float *momentum) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
|
||||
accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
|
||||
variable[i] -= __float2half(learning_rate[0]) * accumulation[i];
|
||||
|
@ -36,8 +35,7 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable,
|
|||
}
|
||||
template <>
|
||||
__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
|
||||
const float *learning_rate, const half *gradient,
|
||||
const float *momentum) {
|
||||
const float *learning_rate, const half *gradient, const float *momentum) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
|
||||
accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
|
||||
variable[i] -= learning_rate[0] * accumulation[i];
|
||||
|
@ -51,15 +49,68 @@ void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, con
|
|||
learning_rate, gradient, momentum);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void FusedMomentumWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable,
|
||||
T *accumulation, const T *learning_rate, const S *gradient,
|
||||
const T *momentum) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
|
||||
T grad = (variable[i] * weight_decay[0] + static_cast<T>(gradient[i])) * scale[0];
|
||||
accumulation[i] = momentum[0] * accumulation[i] + grad;
|
||||
variable[i] -= learning_rate[0] * accumulation[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable, T *accumulation,
|
||||
const T *learning_rate, const S *gradient, const T *momentum,
|
||||
cudaStream_t cuda_stream) {
|
||||
size_t thread_per_block = 256;
|
||||
size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
|
||||
FusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
|
||||
element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation,
|
||||
const T *learning_rate, const S *gradient, const T *momentum) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
|
||||
accumulation[i] = momentum[0] * accumulation[i] + static_cast<T>(gradient[i]);
|
||||
variable[i] -= learning_rate[0] * accumulation[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
|
||||
const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
|
||||
size_t thread_per_block = 256;
|
||||
size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
|
||||
FusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
|
||||
element_num, scale, variable, accumulation, learning_rate, gradient, momentum);
|
||||
}
|
||||
|
||||
template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
|
||||
const float *learning_rate, const float *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
const float *learning_rate, const float *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
|
||||
const half *learning_rate, const half *gradient,
|
||||
const half *momentum, cudaStream_t cuda_stream);
|
||||
const half *learning_rate, const half *gradient,
|
||||
const half *momentum, cudaStream_t cuda_stream);
|
||||
template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
|
||||
const float *learning_rate, const half *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
const float *learning_rate, const half *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
|
||||
const float *learning_rate, const half *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
const float *learning_rate, const half *gradient,
|
||||
const float *momentum, cudaStream_t cuda_stream);
|
||||
|
||||
template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale,
|
||||
float *variable, float *accumulation, const float *learning_rate,
|
||||
const float *gradient, const float *momentum, cudaStream_t cuda_stream);
|
||||
template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale,
|
||||
float *variable, float *accumulation, const float *learning_rate,
|
||||
const half *gradient, const float *momentum, cudaStream_t cuda_stream);
|
||||
template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation,
|
||||
const float *learning_rate, const float *gradient, const float *momentum,
|
||||
cudaStream_t cuda_stream);
|
||||
template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation,
|
||||
const float *learning_rate, const half *gradient, const float *momentum,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
9
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
Executable file → Normal file
9
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
Executable file → Normal file
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,5 +21,12 @@
|
|||
template <typename T, typename S, typename G>
|
||||
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
|
||||
const S *momentum, cudaStream_t cuda_stream);
|
||||
template <typename T, typename S>
|
||||
void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable, T *accumulation,
|
||||
const T *learning_rate, const S *gradient, const T *momentum,
|
||||
cudaStream_t cuda_stream);
|
||||
template <typename T, typename S>
|
||||
void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
|
||||
const S *gradient, const T *momentum, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(FusedScaleApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32) // scale
|
||||
.AddInputAttr(kNumberTypeFloat32) // variable
|
||||
.AddInputAttr(kNumberTypeFloat32) // accumulation
|
||||
.AddInputAttr(kNumberTypeFloat32) // learning_rate
|
||||
.AddInputAttr(kNumberTypeFloat32) // gradient
|
||||
.AddInputAttr(kNumberTypeFloat32) // momentum
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedScaleMomentumGpuKernel, float, float)
|
||||
MS_REG_GPU_KERNEL_TWO(FusedScaleApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32) // scale
|
||||
.AddInputAttr(kNumberTypeFloat32) // variable
|
||||
.AddInputAttr(kNumberTypeFloat32) // accumulation
|
||||
.AddInputAttr(kNumberTypeFloat32) // variable
|
||||
.AddInputAttr(kNumberTypeFloat32) // accumulation
|
||||
.AddInputAttr(kNumberTypeFloat32) // learning_rate
|
||||
.AddInputAttr(kNumberTypeFloat16) // gradient
|
||||
.AddInputAttr(kNumberTypeFloat32) // momentum
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedScaleMomentumGpuKernel, float, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class FusedScaleMomentumGpuKernel : public GpuKernel {
|
||||
public:
|
||||
FusedScaleMomentumGpuKernel() : element_num_(1) {}
|
||||
~FusedScaleMomentumGpuKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
void *stream_ptr) override {
|
||||
T *scale = GetDeviceAddress<T>(inputs, 0);
|
||||
T *variable = GetDeviceAddress<T>(inputs, 1);
|
||||
T *accumulation = GetDeviceAddress<T>(inputs, 2);
|
||||
T *learning_rate = GetDeviceAddress<T>(inputs, 3);
|
||||
S *gradient = GetDeviceAddress<S>(inputs, 4);
|
||||
T *momentum = GetDeviceAddress<T>(inputs, 5);
|
||||
|
||||
FusedScaleMomentum(element_num_, scale, variable, accumulation, learning_rate, gradient, momentum,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 6) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 6 inputs.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < variable_shape.size(); i++) {
|
||||
element_num_ *= variable_shape[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(element_num_ * sizeof(T));
|
||||
input_size_list_.push_back(element_num_ * sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(element_num_ * sizeof(S));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
output_size_list_.push_back(element_num_ * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t element_num_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(FusedWeightScaleApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32) // weight decay
|
||||
.AddInputAttr(kNumberTypeFloat32) // scale
|
||||
.AddInputAttr(kNumberTypeFloat32) // variable
|
||||
.AddInputAttr(kNumberTypeFloat32) // accumulation
|
||||
.AddInputAttr(kNumberTypeFloat32) // learning_rate
|
||||
.AddInputAttr(kNumberTypeFloat32) // gradient
|
||||
.AddInputAttr(kNumberTypeFloat32) // momentum
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedWeightDecayScaleMomentumGpuKernel, float, float)
|
||||
MS_REG_GPU_KERNEL_TWO(FusedWeightScaleApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32) // variable
|
||||
.AddInputAttr(kNumberTypeFloat32) // accumulation
|
||||
.AddInputAttr(kNumberTypeFloat32) // variable
|
||||
.AddInputAttr(kNumberTypeFloat32) // accumulation
|
||||
.AddInputAttr(kNumberTypeFloat32) // learning_rate
|
||||
.AddInputAttr(kNumberTypeFloat16) // gradient
|
||||
.AddInputAttr(kNumberTypeFloat32) // momentum
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedWeightDecayScaleMomentumGpuKernel, float, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,87 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel {
|
||||
public:
|
||||
FusedWeightDecayScaleMomentumGpuKernel() : element_num_(1) {}
|
||||
~FusedWeightDecayScaleMomentumGpuKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
void *stream_ptr) override {
|
||||
T *weight_decay = GetDeviceAddress<T>(inputs, 0);
|
||||
T *scale = GetDeviceAddress<T>(inputs, 1);
|
||||
T *variable = GetDeviceAddress<T>(inputs, 2);
|
||||
T *accumulation = GetDeviceAddress<T>(inputs, 3);
|
||||
T *learning_rate = GetDeviceAddress<T>(inputs, 4);
|
||||
S *gradient = GetDeviceAddress<S>(inputs, 5);
|
||||
T *momentum = GetDeviceAddress<T>(inputs, 6);
|
||||
|
||||
FusedWeightDecayScaleMomentum(element_num_, weight_decay, scale, variable, accumulation, learning_rate, gradient,
|
||||
momentum, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 7) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 7 inputs.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < variable_shape.size(); i++) {
|
||||
element_num_ *= variable_shape[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(element_num_ * sizeof(T));
|
||||
input_size_list_.push_back(element_num_ * sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(element_num_ * sizeof(S));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
output_size_list_.push_back(element_num_ * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t element_num_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
|
||||
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
|
||||
VectorRef apply_momentum =
|
||||
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
|
||||
return apply_momentum;
|
||||
}
|
||||
|
||||
const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]);
|
||||
auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
|
||||
auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
|
||||
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
|
||||
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
|
||||
MS_EXCEPTION_IF_NULL(scale);
|
||||
MS_EXCEPTION_IF_NULL(variable);
|
||||
MS_EXCEPTION_IF_NULL(accumulation);
|
||||
MS_EXCEPTION_IF_NULL(learning_rate);
|
||||
MS_EXCEPTION_IF_NULL(gradient);
|
||||
MS_EXCEPTION_IF_NULL(momentum);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedScaleApplyMomentum);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), scale, variable, accumulation,
|
||||
learning_rate, gradient, momentum};
|
||||
auto replace_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
|
||||
replace_node->set_scope(node->scope());
|
||||
return replace_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ApplyMomentumScaleFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
|
||||
scale_ = std::make_shared<Var>();
|
||||
variable_ = std::make_shared<Var>();
|
||||
accumulation_ = std::make_shared<Var>();
|
||||
learning_rate_ = std::make_shared<Var>();
|
||||
gradient_ = std::make_shared<Var>();
|
||||
momentum_ = std::make_shared<Var>();
|
||||
}
|
||||
~ApplyMomentumScaleFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
VarPtr scale_;
|
||||
VarPtr variable_;
|
||||
VarPtr accumulation_;
|
||||
VarPtr learning_rate_;
|
||||
VarPtr gradient_;
|
||||
VarPtr momentum_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
|
||||
VectorRef weight = VectorRef(
|
||||
{prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
|
||||
VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_});
|
||||
VectorRef apply_momentum =
|
||||
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
|
||||
return apply_momentum;
|
||||
}
|
||||
|
||||
const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto weight_decay = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
|
||||
auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]);
|
||||
auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
|
||||
auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
|
||||
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
|
||||
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
|
||||
MS_EXCEPTION_IF_NULL(weight_decay);
|
||||
MS_EXCEPTION_IF_NULL(scale);
|
||||
MS_EXCEPTION_IF_NULL(variable);
|
||||
MS_EXCEPTION_IF_NULL(accumulation);
|
||||
MS_EXCEPTION_IF_NULL(learning_rate);
|
||||
MS_EXCEPTION_IF_NULL(gradient);
|
||||
MS_EXCEPTION_IF_NULL(momentum);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable,
|
||||
accumulation, learning_rate, gradient, momentum};
|
||||
auto replace_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
|
||||
replace_node->set_scope(node->scope());
|
||||
return replace_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
|
||||
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
|
||||
weight_decay_ = std::make_shared<Var>();
|
||||
scale_ = std::make_shared<Var>();
|
||||
variable_ = std::make_shared<Var>();
|
||||
accumulation_ = std::make_shared<Var>();
|
||||
learning_rate_ = std::make_shared<Var>();
|
||||
gradient_ = std::make_shared<Var>();
|
||||
momentum_ = std::make_shared<Var>();
|
||||
}
|
||||
~ApplyMomentumWeightDecayScaleFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
VarPtr weight_decay_;
|
||||
VarPtr scale_;
|
||||
|
||||
VarPtr variable_;
|
||||
VarPtr accumulation_;
|
||||
VarPtr learning_rate_;
|
||||
VarPtr gradient_;
|
||||
VarPtr momentum_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_
|
|
@ -0,0 +1,175 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
const std::vector<int> kOutputIndex{0, 1, 2};
|
||||
constexpr size_t kBNGradOutputNum = 3;
|
||||
constexpr size_t kBNAddReluGradOutputNum = 4;
|
||||
|
||||
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_outputs);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager->node_users().find(bn) == manager->node_users().end()) {
|
||||
return false;
|
||||
}
|
||||
size_t output_num = 0;
|
||||
for (const auto &node_index : manager->node_users()[bn]) {
|
||||
const AnfNodePtr &output = node_index.first;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
auto tuple_getiterm_cnode = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
|
||||
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
auto value_node = index_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
int index = GetValue<int>(value_node->value());
|
||||
if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) {
|
||||
return false;
|
||||
}
|
||||
bn_outputs->push_back(output);
|
||||
output_num++;
|
||||
}
|
||||
return output_num == kBNGradOutputNum;
|
||||
}
|
||||
|
||||
void SetShapeAndType(const CNodePtr &bn_add_relu_grad, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad) {
|
||||
// set output shape and dtype
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad);
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(bn_grad, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(bn_grad, i));
|
||||
}
|
||||
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, 0));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(relu_grad, 0));
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, bn_add_relu_grad.get());
|
||||
}
|
||||
|
||||
void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad,
|
||||
const CNodePtr &bn_add_relu_grad) {
|
||||
// Create outputs
|
||||
std::vector<AnfNodePtr> bn_add_relu_grad_output;
|
||||
CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output);
|
||||
if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The output size of node " << kFusedBatchNormGradExWithAddAndActivation << " must be "
|
||||
<< kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size();
|
||||
}
|
||||
|
||||
// Get bn outputs
|
||||
std::vector<AnfNodePtr> bn_outputs;
|
||||
if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) {
|
||||
MS_LOG(INFO) << "The " << prim::kPrimFusedBatchNormGradEx
|
||||
<< " node should only have output 0, 1 and 2. The node should not be changed";
|
||||
return;
|
||||
}
|
||||
|
||||
// Replace orignal output
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem);
|
||||
size_t output_index = 0;
|
||||
for (const auto &output : bn_outputs) {
|
||||
(void)manager->Replace(output, bn_add_relu_grad_output[output_index]);
|
||||
output_index++;
|
||||
}
|
||||
|
||||
manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
|
||||
VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
|
||||
VectorRef batch_norm_grad =
|
||||
VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
|
||||
return batch_norm_grad;
|
||||
}
|
||||
|
||||
const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(relu_grad);
|
||||
auto relu_users = GetRealNodeUsedList(graph, relu_grad);
|
||||
if (relu_users->size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// process pattern as Relu(TensorAdd(BN#0, BN#1))
|
||||
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
|
||||
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
|
||||
MS_EXCEPTION_IF_NULL(dy);
|
||||
auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);
|
||||
MS_EXCEPTION_IF_NULL(y);
|
||||
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 2);
|
||||
MS_EXCEPTION_IF_NULL(scale);
|
||||
auto save_mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
|
||||
MS_EXCEPTION_IF_NULL(save_mean);
|
||||
auto save_var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 4);
|
||||
MS_EXCEPTION_IF_NULL(save_var);
|
||||
auto reserve = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
|
||||
MS_EXCEPTION_IF_NULL(reserve);
|
||||
auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(save_mean), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
|
||||
MS_EXCEPTION_IF_NULL(bias);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithAddAndActivation);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
|
||||
auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(fused_batch_norm_add_relu_grad);
|
||||
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
|
||||
SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
|
||||
ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BatchNormAddReluGradFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit BatchNormAddReluGradFusion(bool multigraph = true)
|
||||
: PatternProcessPass("batch_norm_add_relu_grad_fusion", multigraph) {
|
||||
dy_ = std::make_shared<Var>();
|
||||
y_ = std::make_shared<Var>();
|
||||
x_ = std::make_shared<Var>();
|
||||
scale_ = std::make_shared<Var>();
|
||||
bias_ = std::make_shared<Var>();
|
||||
mean_ = std::make_shared<Var>();
|
||||
var_ = std::make_shared<Var>();
|
||||
save_mean_ = std::make_shared<Var>();
|
||||
save_var_ = std::make_shared<Var>();
|
||||
reserve_ = std::make_shared<Var>();
|
||||
}
|
||||
~BatchNormAddReluGradFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
VarPtr dy_;
|
||||
VarPtr y_;
|
||||
VarPtr x_;
|
||||
VarPtr scale_;
|
||||
VarPtr bias_;
|
||||
VarPtr mean_;
|
||||
VarPtr var_;
|
||||
VarPtr save_mean_;
|
||||
VarPtr save_var_;
|
||||
VarPtr reserve_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_RELU_GRAD_FUSION_H_
|
|
@ -26,11 +26,14 @@
|
|||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
|
||||
#include "backend/optimizer/gpu/adam_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
|
||||
#include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
|
||||
#include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
|
||||
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
|
||||
#include "backend/optimizer/gpu/replace_addn_fusion.h"
|
||||
#include "backend/optimizer/gpu/insert_format_transform_op.h"
|
||||
|
@ -73,6 +76,8 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AdamFusion>());
|
||||
// pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
||||
// pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||
|
@ -81,6 +86,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
|
||||
}
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
|
|
@ -193,6 +193,8 @@ constexpr auto kPaddingOpName = "Padding";
|
|||
constexpr auto kAvgPoolOpName = "AvgPool";
|
||||
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
|
||||
constexpr auto kTensorAddOpName = "TensorAdd";
|
||||
constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
|
||||
constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
|
||||
|
||||
// attr key name
|
||||
constexpr auto kAttrInputNames = "input_names";
|
||||
|
|
Loading…
Reference in New Issue