diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu
new file mode 100644
index 00000000000..dfadaa09d6c
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "adam_weight_decay_impl.cuh"
+#include "device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1,
+                                      const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
+                                      const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v,
+                                      T *param, T *gradient) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) {
+    float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i];
+    float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i];
+    float update = next_m / (sqrt(next_v) + epsilon[0]);
+    if (need_decay && weight_decay != nullptr) {
+      update += weight_decay[0] * param[i];
+    }
+    param[i] -= lr[0] * update;
+    m[i] = next_m;
+    v[i] = next_v;
+  }
+}
+
+template <typename T>
+void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
+                     const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
+                     const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) {
+  AdamWeightDecayKernel<<<GET_BLOCKS(element_num_), GET_THREADS, 0, stream>>>(
+    element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param,
+    gradient);
+}
+
+template void AdamWeightDecay<float>(const int &element_num_, const bool &need_decay, const float *beta1,
+                                     const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
+                                     const float *epsilon, const float *lr, const float *weight_decay, float *m,
+                                     float *v, float *param, float *gradient, cudaStream_t stream);
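For reference, here is a minimal CPU sketch of the same per-element update (not part of this patch; the function name and the pass-by-value scalar signature are illustrative). It can back a host-side sanity check of the kernel output:

```cpp
#include <cmath>
#include <cstddef>

// Host-side reference of the fused AdamWeightDecay step. The kernel reads the
// scalars (beta1, lr, ...) from device memory; here they are plain values.
void AdamWeightDecayReference(std::size_t n, bool need_decay, float beta1, float one_sub_beta1, float beta2,
                              float one_sub_beta2, float epsilon, float lr, float weight_decay, float *m, float *v,
                              float *param, const float *gradient) {
  for (std::size_t i = 0; i < n; ++i) {
    float next_m = beta1 * m[i] + one_sub_beta1 * gradient[i];
    float next_v = beta2 * v[i] + one_sub_beta2 * gradient[i] * gradient[i];
    float update = next_m / (std::sqrt(next_v) + epsilon);
    if (need_decay) {
      update += weight_decay * param[i];  // decay term added before learning-rate scaling
    }
    param[i] -= lr * update;
    m[i] = next_m;
    v[i] = next_v;
  }
}
```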
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh
new file mode 100644
index 00000000000..2addffbf002
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
+template <typename T>
+void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
+                     const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
+                     const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ADAM_WEIGHT_DECAY_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc
new file mode 100644
index 00000000000..77cb7f86086
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/nn/fused_adam_weight_decay.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedAdamWeightDecayGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(FusedAdam,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedAdamWeightDecayGpuKernel, float)
+
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h
new file mode 100644
index 00000000000..f13f6ed59fb
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
+
+#include <vector>
+#include "kernel/gpu/gpu_kernel.h"
+#include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/kernel_constants.h"
+#include "kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class FusedAdamWeightDecayGpuKernel : public GpuKernel {
+ public:
+  FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {}
+  ~FusedAdamWeightDecayGpuKernel() override = default;
+
+  bool Init(const CNodePtr &kernel_node) override {
+    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (node_name == "FusedAdamWeightDecay") {
+      weight_decay_ = true;
+    }
+
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7);
+    element_nums_ = 1;
+    for (auto i : shape) {
+      element_nums_ *= i;
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    float *beta1 = GetDeviceAddress<float>(inputs, 0);
+    float *one_sub_beta1 = GetDeviceAddress<float>(inputs, 1);
+    float *beta2 = GetDeviceAddress<float>(inputs, 2);
+    float *one_sub_beta2 = GetDeviceAddress<float>(inputs, 3);
+    float *epsilon = GetDeviceAddress<float>(inputs, 4);
+    float *lr = GetDeviceAddress<float>(inputs, 5);
+    T *param = GetDeviceAddress<T>(inputs, 6);
+    T *m = GetDeviceAddress<T>(inputs, 7);
+    T *v = GetDeviceAddress<T>(inputs, 8);
+    T *gradient = GetDeviceAddress<T>(inputs, 9);
+    float *weight_decay = nullptr;
+    if (weight_decay_) {
+      weight_decay = GetDeviceAddress<float>(inputs, 10);
+    }
+    AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v,
+                    param, gradient, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+ protected:
+  void InitResource() override {}
+  void InitSizeLists() override {
+    input_size_list_.push_back(sizeof(float));              // beta1
+    input_size_list_.push_back(sizeof(float));              // one_sub_beta1
+    input_size_list_.push_back(sizeof(float));              // beta2
+    input_size_list_.push_back(sizeof(float));              // one_sub_beta2
+    input_size_list_.push_back(sizeof(float));              // epsilon
+    input_size_list_.push_back(sizeof(float));              // lr
+    input_size_list_.push_back(element_nums_ * sizeof(T));  // param
+    input_size_list_.push_back(element_nums_ * sizeof(T));  // m
+    input_size_list_.push_back(element_nums_ * sizeof(T));  // v
+    input_size_list_.push_back(element_nums_ * sizeof(T));  // gradient
+    if (weight_decay_) {
+      input_size_list_.push_back(sizeof(float));            // weight_decay
+    }
+    output_size_list_.push_back(element_nums_ * sizeof(T));
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int element_nums_;
+  bool weight_decay_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index 01812a55291..522b80def6c 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -182,9 +182,11 @@ extern const PrimitivePtr kPrimReduceMin;
 extern const PrimitivePtr kPrimNeg;
 extern const PrimitivePtr kPrimSub;
 extern const PrimitivePtr kPrimMul;
+extern const PrimitivePtr kPrimRealDiv;
 extern const PrimitivePtr kPrimMinimum;
 extern const PrimitivePtr kPrimMaximum;
 extern const PrimitivePtr kPrimSquare;
+extern const PrimitivePtr kPrimSqrt;
 extern const PrimitivePtr kPrimEqual;
 extern const PrimitivePtr kPrimLess;
 extern const PrimitivePtr kPrimLessEqual;
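For reference, the subgraph that the AdamFusion pass below matches computes the plain Adam step (no bias correction in this formulation; the `one_sub_beta1`/`one_sub_beta2` inputs arrive precomputed as $1-\beta_1$ and $1-\beta_2$):

```latex
\begin{aligned}
m_{t+1} &= \beta_1 m_t + (1 - \beta_1)\, g_t \\
v_{t+1} &= \beta_2 v_t + (1 - \beta_2)\, g_t^2 \\
\theta_{t+1} &= \theta_t - \eta \cdot \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
\end{aligned}
```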
diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc b/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc
new file mode 100644
index 00000000000..8111ee429d9
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc
@@ -0,0 +1,112 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pre_activate/gpu/adam_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "pre_activate/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
+  std::vector<std::string> inputs_format;
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> inputs_type;
+  std::vector<TypeId> outputs_type;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
+    inputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
+    outputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  builder.SetInputsDeviceType(inputs_type);
+  builder.SetInputsFormat(inputs_format);
+  builder.SetOutputsDeviceType(outputs_type);
+  builder.SetOutputsFormat(outputs_format);
+  return builder.Build();
+}
+}  // namespace
+
+const BaseRef AdamFusion::DefinePattern() const {
+  VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
+                                VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
+  VectorRef next_v =
+    VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
+               VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
+  VectorRef update = VectorRef(
+    {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
+  VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
+  VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
+  VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
+  VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
+  VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
+  return depend3;
+}
+
+const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
+  auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
+  auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
+  auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
+  auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
+  auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
+  auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
+  auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
+  auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
+  auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
+  MS_EXCEPTION_IF_NULL(beta1_input);
+  MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
+  MS_EXCEPTION_IF_NULL(beta2_input);
+  MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
+  MS_EXCEPTION_IF_NULL(eps_input);
+  MS_EXCEPTION_IF_NULL(lr_input);
+  MS_EXCEPTION_IF_NULL(param_input);
+  MS_EXCEPTION_IF_NULL(m_input);
+  MS_EXCEPTION_IF_NULL(v_input);
+  MS_EXCEPTION_IF_NULL(gradient_input);
+
+  auto prim = std::make_shared<Primitive>(kFusedAdamName);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {
+    NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
+    eps_input,          lr_input,    param_input,         m_input,     v_input,
+    gradient_input};
+  auto adam = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(adam);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
+  adam->set_scope(node->scope());
+
+  auto build_info = GenerateKernelBuildInfo(adam);
+  AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
+  return adam;
+}
+}  // namespace opt
+}  // namespace mindspore
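The input order assembled in `Process` must line up with the indices read in `FusedAdamWeightDecayGpuKernel::Launch`. A hypothetical enum (illustrative only, not part of the patch) spells out that contract:

```cpp
// Position k in the fused CNode's input list (after the primitive value node)
// corresponds to GetDeviceAddress(inputs, k) in the GPU kernel's Launch().
enum FusedAdamInputIndex : int {
  kIdxBeta1 = 0,
  kIdxOneSubBeta1 = 1,
  kIdxBeta2 = 2,
  kIdxOneSubBeta2 = 3,
  kIdxEpsilon = 4,
  kIdxLearningRate = 5,
  kIdxParam = 6,
  kIdxMoment1 = 7,
  kIdxMoment2 = 8,
  kIdxGradient = 9,
  kIdxWeightDecay = 10,  // appended last, present only for FusedAdamWeightDecay
};
```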
diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h b/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h
new file mode 100644
index 00000000000..d8c10a0986a
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
+
+#include <memory>
+#include "pre_activate/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class AdamFusion : public PatternProcessPass {
+ public:
+  explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) {
+    beta1_ = std::make_shared<Var>();
+    one_sub_beta1_ = std::make_shared<Var>();
+    beta2_ = std::make_shared<Var>();
+    one_sub_beta2_ = std::make_shared<Var>();
+    eps_ = std::make_shared<Var>();
+    lr_ = std::make_shared<Var>();
+    param_ = std::make_shared<Var>();
+    m_ = std::make_shared<Var>();
+    v_ = std::make_shared<Var>();
+    gradient_ = std::make_shared<Var>();
+  }
+  ~AdamFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr beta1_;
+  VarPtr one_sub_beta1_;
+  VarPtr beta2_;
+  VarPtr one_sub_beta2_;
+  VarPtr eps_;
+  VarPtr lr_;
+  VarPtr param_;
+  VarPtr m_;
+  VarPtr v_;
+  VarPtr gradient_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc
new file mode 100644
index 00000000000..c950cbd56fd
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc
@@ -0,0 +1,117 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pre_activate/gpu/adam_weight_decay_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "pre_activate/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
+  std::vector<std::string> inputs_format;
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> inputs_type;
+  std::vector<TypeId> outputs_type;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
+    inputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
+    outputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  builder.SetInputsDeviceType(inputs_type);
+  builder.SetInputsFormat(inputs_format);
+  builder.SetOutputsDeviceType(outputs_type);
+  builder.SetOutputsFormat(outputs_format);
+  return builder.Build();
+}
+}  // namespace
+
+const BaseRef AdamWeightDecayFusion::DefinePattern() const {
+  VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
+                                VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
+  VectorRef next_v =
+    VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
+               VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
+  VectorRef update = VectorRef(
+    {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
+  VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
+
+  VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
+  VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
+  VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
+  VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
+  VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
+  return depend3;
+}
+
+const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
+  auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
+  auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
+  auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
+  auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
+  auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
+  auto weight_decay_input = utils::cast<AnfNodePtr>((*equiv)[weight_decay_]);
+  auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
+  auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
+  auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
+  auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
+  MS_EXCEPTION_IF_NULL(beta1_input);
+  MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
+  MS_EXCEPTION_IF_NULL(beta2_input);
+  MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
+  MS_EXCEPTION_IF_NULL(eps_input);
+  MS_EXCEPTION_IF_NULL(lr_input);
+  MS_EXCEPTION_IF_NULL(weight_decay_input);
+  MS_EXCEPTION_IF_NULL(param_input);
+  MS_EXCEPTION_IF_NULL(m_input);
+  MS_EXCEPTION_IF_NULL(v_input);
+  MS_EXCEPTION_IF_NULL(gradient_input);
+
+  auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {
+    NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
+    eps_input,          lr_input,    param_input,         m_input,     v_input,
+    gradient_input,     weight_decay_input};
+  auto adam_weight_decay = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(adam_weight_decay);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get());
+  adam_weight_decay->set_scope(node->scope());
+
+  auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
+  AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
+  return adam_weight_decay;
+}
+}  // namespace opt
+}  // namespace mindspore
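AdamWeightDecayFusion matches the same subgraph with one extra term: the decay contribution $\lambda\,\theta_t$ is added to the update before the learning-rate scaling, matching both `new_update` in the pattern above and the `need_decay` branch in the CUDA kernel:

```latex
\theta_{t+1} = \theta_t - \eta \left( \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda\, \theta_t \right)
```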
diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h b/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h
new file mode 100644
index 00000000000..0ada5756e30
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
+
+#include <memory>
+#include "pre_activate/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class AdamWeightDecayFusion : public PatternProcessPass {
+ public:
+  explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) {
+    beta1_ = std::make_shared<Var>();
+    one_sub_beta1_ = std::make_shared<Var>();
+    beta2_ = std::make_shared<Var>();
+    one_sub_beta2_ = std::make_shared<Var>();
+    eps_ = std::make_shared<Var>();
+    lr_ = std::make_shared<Var>();
+    weight_decay_ = std::make_shared<Var>();
+    param_ = std::make_shared<Var>();
+    m_ = std::make_shared<Var>();
+    v_ = std::make_shared<Var>();
+    gradient_ = std::make_shared<Var>();
+  }
+  ~AdamWeightDecayFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr beta1_;
+  VarPtr one_sub_beta1_;
+  VarPtr beta2_;
+  VarPtr one_sub_beta2_;
+  VarPtr eps_;
+  VarPtr lr_;
+  VarPtr weight_decay_;
+  VarPtr param_;
+  VarPtr m_;
+  VarPtr v_;
+  VarPtr gradient_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_
diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc
index a0a43f2edda..85ad2f3d1e2 100644
--- a/mindspore/ccsrc/session/gpu_session.cc
+++ b/mindspore/ccsrc/session/gpu_session.cc
@@ -23,6 +23,8 @@
 #include "pre_activate/common/helper.h"
 #include "pre_activate/pass/communication_op_fusion.h"
 #include "pre_activate/pass/getitem_tuple.h"
+#include "pre_activate/gpu/adam_weight_decay_fusion.h"
+#include "pre_activate/gpu/adam_fusion.h"
 #include "device/kernel_runtime_manager.h"
 #include "predict/predict.h"
 #include "common/utils.h"
@@ -53,6 +55,16 @@ void GPUSession::StartKernelRT() const {
 
 void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
+  pm->AddPass(std::make_shared<opt::AdamFusion>());
+  optimizer->AddPassManager(pm);
+  (void)optimizer->Optimize(kernel_graph);
+  kernel_graph->SetExecOrderByDefault();
+}
+
+void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
   pm->AddPass(std::make_shared<opt::AllReduceFusion>());
@@ -151,14 +163,16 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   auto graph_id = graph_sum_;
   auto graph = ConstructKernelGraph(lst, outputs);
   MS_EXCEPTION_IF_NULL(graph);
+  // Optimize
+  Optimize(graph);
   // Select kernel build info
   SelectKernel(graph);
   // Convert kernel Graph to model
   predictmodel::StepConvertGraph(graph);
   // Start gpu kernel runtime
   StartKernelRT();
-  // AllReduce Optimize
-  Optimize(graph);
+  // HardwareOptimize
+  HardwareOptimize(graph);
   // Assign CUDA streams
   AssignStream(graph);
   // Hide NoOp from execution graph
diff --git a/mindspore/ccsrc/session/gpu_session.h b/mindspore/ccsrc/session/gpu_session.h
index 0dfb815abe7..4e46c2138d0 100644
--- a/mindspore/ccsrc/session/gpu_session.h
+++ b/mindspore/ccsrc/session/gpu_session.h
@@ -51,6 +51,8 @@ class GPUSession : public SessionBasic {
 
   void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
 
+  void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
+
   void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
 
   void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 7380ef501f7..f80c13c9a19 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -161,6 +161,8 @@ constexpr auto kNMSWithMaskOpName = "NMSWithMask";
 constexpr auto kSoftmaxGradExtOpName = "SoftmaxGradExt";
 constexpr auto kStridedReadOpName = "StridedRead";
 constexpr auto kStridedWriteOpName = "StridedWrite";
+constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
+constexpr auto kFusedAdamName = "FusedAdam";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
diff --git a/tests/st/ops/gpu/test_adam_fusion.py b/tests/st/ops/gpu/test_adam_fusion.py
new file mode 100644
index 00000000000..f0595d12a1e
--- /dev/null
+++ b/tests/st/ops/gpu/test_adam_fusion.py
@@ -0,0 +1,87 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+
+
+class Net(nn.Cell):
+    def __init__(self, decay_flag=True):
+        super(Net, self).__init__()
+        self.decay_flag = decay_flag
+        self.op_mul = P.Mul()
+        self.op_square = P.Square()
+        self.op_sqrt = P.Sqrt()
+        self.op_cast = P.Cast()
+        self.op_reshape = P.Reshape()
+        self.op_shape = P.Shape()
+        self.param = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='param')
+        self.m = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='m')
+        self.v = Parameter(Tensor(np.array([0.1, 0.3, 0.5]).astype(np.float32)), name='v')
+
+    @ms_function
+    def construct(self, beta1, beta2, gradient, eps, weight_decay_tensor, lr):
+        param_fp32 = self.op_cast(self.param, mstype.float32)
+        m_fp32 = self.op_cast(self.m, mstype.float32)
+        v_fp32 = self.op_cast(self.v, mstype.float32)
+        gradient_fp32 = self.op_cast(gradient, mstype.float32)
+
+        next_m = self.op_mul(beta1, m_fp32) + \
+            self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
+        next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(F.tuple_to_array((1.0,)), mstype.float32) - \
+            beta2, self.op_square(gradient_fp32))
+        update = next_m / (eps + self.op_sqrt(next_v))
+        if self.decay_flag:
+            update = self.op_mul(weight_decay_tensor, param_fp32) + update
+        update_with_lr = self.op_mul(lr, update)
+        next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32))
+
+        next_v = F.depend(next_v, F.assign(self.param, next_param))
+        next_v = F.depend(next_v, F.assign(self.m, next_m))
+        next_v = F.depend(next_v, F.assign(self.v, next_v))
+        return next_v
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test():
+    beta1 = Tensor(np.array([0.9]).astype(np.float32))
+    beta2 = Tensor(np.array([0.999]).astype(np.float32))
+    lr = Tensor(np.array([0.001]).astype(np.float32))
+    eps = Tensor(np.array([1e-6]).astype(np.float32))
+    weight_decay_tensor = Tensor(np.array([0.001]).astype(np.float32))
+
+    gradient = Tensor(np.array([0.01, 0.03, 0.05]).astype(np.float32))
+    opt = Net(True)
+    _ = opt(beta1, beta2, gradient, eps, weight_decay_tensor, lr)
+
+    param_expect = np.array([0.09971199, 0.29950103, 0.4993557]).astype(np.float32)
+    m_expect = np.array([0.091, 0.273, 0.45499998]).astype(np.float32)
+    v_expect = np.array([0.0999001, 0.29970092, 0.4995025]).astype(np.float32)
+    assert np.allclose(opt.param.data.asnumpy(), param_expect)
+    assert np.allclose(opt.m.data.asnumpy(), m_expect)
+    assert np.allclose(opt.v.data.asnumpy(), v_expect)
diff --git a/tests/st/ops/gpu/test_batch_matmul.py b/tests/st/ops/gpu/test_batch_matmul.py
index e8450bd81d4..7dbab738011 100644
--- a/tests/st/ops/gpu/test_batch_matmul.py
+++ b/tests/st/ops/gpu/test_batch_matmul.py
@@ -119,6 +119,10 @@ def test_4d_transpose_ab():
                 [[5612, 5810, 6008, 6206]]]]
     assert (output.asnumpy() == expect).all()
 
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_4D_fp16():
     input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16)
     input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16)
@@ -126,13 +130,13 @@ def test_4D_fp16():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     net = BatchMatMulNet()
     output = net(input_x, input_y)
-    expect = [[[[20, 23, 26, 29]],
-               [[200, 212, 224, 236]],
-               [[596, 617, 638, 659]],
-               [[1208, 1238, 1268, 1298]]],
-
-              [[[2036, 2075, 2114, 2153]],
-               [[3080, 3128, 3176, 3224]],
-               [[4340, 4397, 4454, 4511]],
-               [[5816, 5882, 5948, 6014]]]]
+    expect = np.array([[[[20, 23, 26, 29]],
+                        [[200, 212, 224, 236]],
+                        [[596, 617, 638, 659]],
+                        [[1208, 1238, 1268, 1298]]],
+
+                       [[[2036, 2076, 2114, 2152]],
+                        [[3080, 3128, 3176, 3224]],
+                        [[4340, 4396, 4456, 4510]],
+                        [[5816, 5880, 5948, 6016]]]]).astype(np.float16)
     assert (output.asnumpy() == expect).all()
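As a sanity check, the expected values in `test_adam_fusion.py` follow directly from the update equations. For the first element ($\theta = m = v = 0.1$, $g = 0.01$, $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\eta = 0.001$, $\epsilon = 10^{-6}$, $\lambda = 0.001$):

```latex
\begin{aligned}
m' &= 0.9 \cdot 0.1 + 0.1 \cdot 0.01 = 0.091 \\
v' &= 0.999 \cdot 0.1 + 0.001 \cdot 0.01^2 = 0.0999001 \\
\theta' &= 0.1 - 0.001 \left( \frac{0.091}{\sqrt{0.0999001} + 10^{-6}} + 0.001 \cdot 0.1 \right) \approx 0.0997120
\end{aligned}
```

which matches `m_expect[0] = 0.091`, `v_expect[0] = 0.0999001`, and `param_expect[0] = 0.09971199`.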