!5612 ResNet50 pattern fusion

Merge pull request !5612 from chenweifeng/BatchNormAddReluGrad
This commit is contained in:
mindspore-ci-bot 2020-09-03 10:52:05 +08:00 committed by Gitee
commit 1944b8e53b
14 changed files with 810 additions and 14 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,8 +26,7 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *
}
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation,
const float *learning_rate, const half *gradient, const float *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
variable[i] -= __float2half(learning_rate[0]) * accumulation[i];
@ -36,8 +35,7 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable,
}
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const half *gradient, const float *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
variable[i] -= learning_rate[0] * accumulation[i];
@ -51,15 +49,68 @@ void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, con
learning_rate, gradient, momentum);
return;
}
template <typename T, typename S>
__global__ void FusedMomentumWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable,
T *accumulation, const T *learning_rate, const S *gradient,
const T *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
T grad = (variable[i] * weight_decay[0] + static_cast<T>(gradient[i])) * scale[0];
accumulation[i] = momentum[0] * accumulation[i] + grad;
variable[i] -= learning_rate[0] * accumulation[i];
}
}
template <typename T, typename S>
void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum,
cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
FusedMomentumWeightDecayScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
element_num, weight_decay, scale, variable, accumulation, learning_rate, gradient, momentum);
}
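In symbols, the fused kernel performs the momentum update with weight decay folded into the gradient and the loss-scale correction applied in a single pass (v = variable, a = accumulation, g = gradient, λ = weight_decay, s = scale, η = learning_rate, μ = momentum):

a ← μ·a + s·(λ·v + g)
v ← v − η·a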
template <typename T, typename S>
__global__ void FusedMomentumScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (element_num); i += blockDim.x * gridDim.x) {
T grad = static_cast<T>(gradient[i]) * scale[0];
accumulation[i] = momentum[0] * accumulation[i] + grad;
variable[i] -= learning_rate[0] * accumulation[i];
}
}
template <typename T, typename S>
void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
const S *gradient, const T *momentum, cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (element_num + thread_per_block - 1) / thread_per_block;
FusedMomentumScaleMomentum<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(
element_num, scale, variable, accumulation, learning_rate, gradient, momentum);
}
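For reference, a minimal single-threaded C++ sketch of the semantics the two kernels above parallelize (illustrative only, not part of this commit; the *Ref function names and parameter shapes are made up):

#include <cstddef>
#include <vector>

// Host-side reference of FusedWeightDecayScaleMomentum: weight decay is folded
// into the gradient, the result is rescaled, then the usual momentum update runs.
void FusedWeightDecayScaleMomentumRef(std::vector<float> *variable, std::vector<float> *accumulation,
                                      const std::vector<float> &gradient, float weight_decay, float scale,
                                      float learning_rate, float momentum) {
  for (std::size_t i = 0; i < variable->size(); ++i) {
    float grad = ((*variable)[i] * weight_decay + gradient[i]) * scale;
    (*accumulation)[i] = momentum * (*accumulation)[i] + grad;
    (*variable)[i] -= learning_rate * (*accumulation)[i];
  }
}

// Host-side reference of FusedScaleMomentum: the same update without the weight-decay term.
void FusedScaleMomentumRef(std::vector<float> *variable, std::vector<float> *accumulation,
                           const std::vector<float> &gradient, float scale, float learning_rate, float momentum) {
  for (std::size_t i = 0; i < variable->size(); ++i) {
    (*accumulation)[i] = momentum * (*accumulation)[i] + gradient[i] * scale;
    (*variable)[i] -= learning_rate * (*accumulation)[i];
  }
}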
template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const float *gradient,
const float *momentum, cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
const half *learning_rate, const half *gradient,
const half *momentum, cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
const float *learning_rate, const half *gradient,
const float *momentum, cudaStream_t cuda_stream);
template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const half *gradient,
const float *momentum, cudaStream_t cuda_stream);
template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale,
float *variable, float *accumulation, const float *learning_rate,
const float *gradient, const float *momentum, cudaStream_t cuda_stream);
template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale,
float *variable, float *accumulation, const float *learning_rate,
const half *gradient, const float *momentum, cudaStream_t cuda_stream);
template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation,
const float *learning_rate, const float *gradient, const float *momentum,
cudaStream_t cuda_stream);
template void FusedScaleMomentum(const size_t element_num, float *scale, float *variable, float *accumulation,
const float *learning_rate, const half *gradient, const float *momentum,
cudaStream_t cuda_stream);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,5 +21,12 @@
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
const S *momentum, cudaStream_t cuda_stream);
template <typename T, typename S>
void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum,
cudaStream_t cuda_stream);
template <typename T, typename S>
void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accumulation, const T *learning_rate,
const S *gradient, const T *momentum, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/fused_scale_momentum_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(FusedScaleApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat32) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
FusedScaleMomentumGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(FusedScaleApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat16) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
FusedScaleMomentumGpuKernel, float, half)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class FusedScaleMomentumGpuKernel : public GpuKernel {
public:
FusedScaleMomentumGpuKernel() : element_num_(1) {}
~FusedScaleMomentumGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
T *scale = GetDeviceAddress<T>(inputs, 0);
T *variable = GetDeviceAddress<T>(inputs, 1);
T *accumulation = GetDeviceAddress<T>(inputs, 2);
T *learning_rate = GetDeviceAddress<T>(inputs, 3);
S *gradient = GetDeviceAddress<S>(inputs, 4);
T *momentum = GetDeviceAddress<T>(inputs, 5);
FusedScaleMomentum(element_num_, scale, variable, accumulation, learning_rate, gradient, momentum,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 6) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 6 inputs.";
return false;
}
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < variable_shape.size(); i++) {
element_num_ *= variable_shape[i];
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(sizeof(T));                  // scale (scalar)
input_size_list_.push_back(element_num_ * sizeof(T));   // variable
input_size_list_.push_back(element_num_ * sizeof(T));   // accumulation
input_size_list_.push_back(sizeof(T));                  // learning_rate (scalar)
input_size_list_.push_back(element_num_ * sizeof(S));   // gradient
input_size_list_.push_back(sizeof(T));                  // momentum (scalar)
output_size_list_.push_back(element_num_ * sizeof(T));  // updated variable
}
private:
size_t element_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_SCALE_MOMENTUM_GPU_KERNEL_H_

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/fused_weightdecay_scale_momentum_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(FusedWeightScaleApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // weight decay
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat32) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
FusedWeightDecayScaleMomentumGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(FusedWeightScaleApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // weight decay
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // variable
.AddInputAttr(kNumberTypeFloat32) // accumulation
.AddInputAttr(kNumberTypeFloat32) // learning_rate
.AddInputAttr(kNumberTypeFloat16) // gradient
.AddInputAttr(kNumberTypeFloat32) // momentum
.AddOutputAttr(kNumberTypeFloat32),
FusedWeightDecayScaleMomentumGpuKernel, float, half)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,87 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class FusedWeightDecayScaleMomentumGpuKernel : public GpuKernel {
public:
FusedWeightDecayScaleMomentumGpuKernel() : element_num_(1) {}
~FusedWeightDecayScaleMomentumGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
T *weight_decay = GetDeviceAddress<T>(inputs, 0);
T *scale = GetDeviceAddress<T>(inputs, 1);
T *variable = GetDeviceAddress<T>(inputs, 2);
T *accumulation = GetDeviceAddress<T>(inputs, 3);
T *learning_rate = GetDeviceAddress<T>(inputs, 4);
S *gradient = GetDeviceAddress<S>(inputs, 5);
T *momentum = GetDeviceAddress<T>(inputs, 6);
FusedWeightDecayScaleMomentum(element_num_, weight_decay, scale, variable, accumulation, learning_rate, gradient,
momentum, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 7) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but FusedMomentum needs 7 inputs.";
return false;
}
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < variable_shape.size(); i++) {
element_num_ *= variable_shape[i];
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(sizeof(T));                  // weight_decay (scalar)
input_size_list_.push_back(sizeof(T));                  // scale (scalar)
input_size_list_.push_back(element_num_ * sizeof(T));   // variable
input_size_list_.push_back(element_num_ * sizeof(T));   // accumulation
input_size_list_.push_back(sizeof(T));                  // learning_rate (scalar)
input_size_list_.push_back(element_num_ * sizeof(S));   // gradient
input_size_list_.push_back(sizeof(T));                  // momentum (scalar)
output_size_list_.push_back(element_num_ * sizeof(T));  // updated variable
}
private:
size_t element_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_WEIGHTDECAY_SCALE_MOMENTUM_GPU_KERNEL_H_

View File

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
VectorRef apply_momentum =
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
return apply_momentum;
}
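Read together with Process below, the pass performs the following graph rewrite (a sketch in the pattern's own notation):

before: ApplyMomentum(variable, accumulation, learning_rate, Mul(gradient, scale), momentum)
after:  FusedScaleApplyMomentum(scale, variable, accumulation, learning_rate, gradient, momentum)

so the scaled gradient is consumed inside the fused kernel instead of being materialized by a separate Mul.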
const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]);
auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
MS_EXCEPTION_IF_NULL(scale);
MS_EXCEPTION_IF_NULL(variable);
MS_EXCEPTION_IF_NULL(accumulation);
MS_EXCEPTION_IF_NULL(learning_rate);
MS_EXCEPTION_IF_NULL(gradient);
MS_EXCEPTION_IF_NULL(momentum);
auto prim = std::make_shared<Primitive>(kFusedScaleApplyMomentum);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), scale, variable, accumulation,
learning_rate, gradient, momentum};
auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
replace_node->set_scope(node->scope());
return replace_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ApplyMomentumScaleFusion : public PatternProcessPass {
public:
explicit ApplyMomentumScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_scale_fusion", multigraph) {
scale_ = std::make_shared<Var>();
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
momentum_ = std::make_shared<Var>();
}
~ApplyMomentumScaleFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr scale_;
VarPtr variable_;
VarPtr accumulation_;
VarPtr learning_rate_;
VarPtr gradient_;
VarPtr momentum_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_SCALE_FUSION_H_

View File

@ -0,0 +1,71 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
VectorRef weight = VectorRef(
{prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_});
VectorRef apply_momentum =
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
return apply_momentum;
}
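As with the scale-only pass, the rewrite this pattern drives can be sketched as:

before: ApplyMomentum(variable, accumulation, learning_rate,
                      Mul(AddN(Mul(variable, weight_decay), Cast(gradient)), scale), momentum)
after:  FusedWeightScaleApplyMomentum(weight_decay, scale, variable, accumulation,
                                      learning_rate, gradient, momentum)

eliminating the intermediate Mul/AddN/Cast tensors; the Cast is absorbed because the fused kernel reads the fp16 gradient directly.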
const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto weight_decay = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
auto scale = utils::cast<AnfNodePtr>((*equiv)[scale_]);
auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]);
auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]);
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
MS_EXCEPTION_IF_NULL(weight_decay);
MS_EXCEPTION_IF_NULL(scale);
MS_EXCEPTION_IF_NULL(variable);
MS_EXCEPTION_IF_NULL(accumulation);
MS_EXCEPTION_IF_NULL(learning_rate);
MS_EXCEPTION_IF_NULL(gradient);
MS_EXCEPTION_IF_NULL(momentum);
auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable,
accumulation, learning_rate, gradient, momentum};
auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, replace_node.get());
replace_node->set_scope(node->scope());
return replace_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
public:
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
weight_decay_ = std::make_shared<Var>();
scale_ = std::make_shared<Var>();
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
momentum_ = std::make_shared<Var>();
}
~ApplyMomentumWeightDecayScaleFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr weight_decay_;
VarPtr scale_;
VarPtr variable_;
VarPtr accumulation_;
VarPtr learning_rate_;
VarPtr gradient_;
VarPtr momentum_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_APPLY_MOMENTUM_WEIGHT_DECAY_SCALE_FUSION_H_

View File

@ -0,0 +1,175 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
#include <algorithm>
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 1, 2};
constexpr size_t kBNGradOutputNum = 3;
constexpr size_t kBNAddReluGradOutputNum = 4;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn_outputs);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (manager->node_users().find(bn) == manager->node_users().end()) {
return false;
}
size_t output_num = 0;
for (const auto &node_index : manager->node_users()[bn]) {
const AnfNodePtr &output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
continue;
}
auto tuple_getitem_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
auto index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(index_node);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int index = GetValue<int>(value_node->value());
if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) {
return false;
}
bn_outputs->push_back(output);
output_num++;
}
return output_num == kBNGradOutputNum;
}
void SetShapeAndType(const CNodePtr &bn_add_relu_grad, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad) {
// set output shape and dtype
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto output_num = AnfAlgo::GetOutputTensorNum(bn_grad);
for (size_t i = 0; i < output_num; ++i) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(bn_grad, i));
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(bn_grad, i));
}
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, 0));
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(relu_grad, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, bn_add_relu_grad.get());
}
void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const AnfNodePtr &relu_grad,
const CNodePtr &bn_add_relu_grad) {
// Create outputs
std::vector<AnfNodePtr> bn_add_relu_grad_output;
CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output);
if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) {
MS_LOG(EXCEPTION) << "The output size of node " << kFusedBatchNormGradExWithAddAndActivation << " must be "
<< kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size();
}
// Get bn outputs
std::vector<AnfNodePtr> bn_outputs;
if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) {
MS_LOG(INFO) << "The " << prim::kPrimFusedBatchNormGradEx
<< " node should only have output 0, 1 and 2. The node should not be changed";
return;
}
// Replace original outputs
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem);
size_t output_index = 0;
for (const auto &output : bn_outputs) {
(void)manager->Replace(output, bn_add_relu_grad_output[output_index]);
output_index++;
}
manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
return;
}
} // namespace
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
VectorRef batch_norm_grad =
VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
return batch_norm_grad;
}
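The intended rewrite, sketched in the pattern's own notation (the reading of the fourth output as the gradient of the residual-add branch follows from ReplaceOutput above, which wires it to relu_grad's users):

before: FusedBatchNormGradEx(ReluGrad(dy, y), x, scale, save_mean, save_var, reserve)                // 3 outputs
after:  FusedBatchNormGradExWithAddAndActivation(dy, x, scale, save_mean, save_var, reserve, bias, y) // 4 outputs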
const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
return nullptr;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);
auto relu_users = GetRealNodeUsedList(graph, relu_grad);
if (relu_users->size() != 2) {
return nullptr;
}
// process pattern as Relu(TensorAdd(BN#0, BN#1))
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
MS_EXCEPTION_IF_NULL(tuple_getitem);
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
return nullptr;
}
auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
MS_EXCEPTION_IF_NULL(dy);
auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);
MS_EXCEPTION_IF_NULL(y);
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 1);
MS_EXCEPTION_IF_NULL(x);
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 2);
MS_EXCEPTION_IF_NULL(scale);
auto save_mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
MS_EXCEPTION_IF_NULL(save_mean);
auto save_var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 4);
MS_EXCEPTION_IF_NULL(save_var);
auto reserve = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
MS_EXCEPTION_IF_NULL(reserve);
auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(save_mean), 0);
MS_EXCEPTION_IF_NULL(batch_norm);
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
MS_EXCEPTION_IF_NULL(bias);
auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithAddAndActivation);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fused_batch_norm_add_relu_grad);
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
return nullptr;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class BatchNormAddReluGradFusion : public PatternProcessPass {
public:
explicit BatchNormAddReluGradFusion(bool multigraph = true)
: PatternProcessPass("batch_norm_add_relu_grad_fusion", multigraph) {
dy_ = std::make_shared<Var>();
y_ = std::make_shared<Var>();
x_ = std::make_shared<Var>();
scale_ = std::make_shared<Var>();
bias_ = std::make_shared<Var>();
mean_ = std::make_shared<Var>();
var_ = std::make_shared<Var>();
save_mean_ = std::make_shared<Var>();
save_var_ = std::make_shared<Var>();
reserve_ = std::make_shared<Var>();
}
~BatchNormAddReluGradFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr dy_;
VarPtr y_;
VarPtr x_;
VarPtr scale_;
VarPtr bias_;
VarPtr mean_;
VarPtr var_;
VarPtr save_mean_;
VarPtr save_var_;
VarPtr reserve_;
};
} // namespace opt
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BATCH_NORM_ADD_RELU_GRAD_FUSION_H_

View File

@ -26,11 +26,14 @@
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
#include "backend/optimizer/gpu/adam_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
#include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
#include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
#include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include "backend/optimizer/gpu/replace_addn_fusion.h"
#include "backend/optimizer/gpu/insert_format_transform_op.h"
@ -73,6 +76,8 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
pm->AddPass(std::make_shared<opt::AdamFusion>());
// pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
// pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
@ -81,6 +86,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
}
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);

View File

@ -193,6 +193,8 @@ constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
constexpr auto kTensorAddOpName = "TensorAdd";
constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
// attr key name
constexpr auto kAttrInputNames = "input_names";