From 0a920057a5e9d6d76974f8f9849c4adabfe94e55 Mon Sep 17 00:00:00 2001
From: zhaosida
Date: Mon, 26 Apr 2021 19:35:06 +0800
Subject: [PATCH] add AdamWeightDecay CPU op and optimize with SIMD intrinsics

---
 build.sh                                                    |   6 +-
 .../backend/kernel_compiler/CMakeLists.txt                  |  12 ++
 .../cpu/adam_weight_decay_cpu_kernel.cc                     | 146 ++++++++++++++++
 .../cpu/adam_weight_decay_cpu_kernel.h                      |  97 +++++++++++
 mindspore/ops/operations/inner_ops.py                       |   2 +-
 model_zoo/official/nlp/gpt/src/adam.py                      | 162 ++++++++++++++++++
 tests/st/ops/cpu/test_adam_weight_decay_op.py               |  66 +++++++
 7 files changed, 489 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h
 create mode 100644 model_zoo/official/nlp/gpt/src/adam.py
 create mode 100644 tests/st/ops/cpu/test_adam_weight_decay_op.py

diff --git a/build.sh b/build.sh
index ec3d8ac530b..1aa25301af4 100755
--- a/build.sh
+++ b/build.sh
@@ -119,6 +119,7 @@ checkopts()
     ANDROID_STL="c++_shared"
     ENABLE_MAKE_CLEAN="off"
     X86_64_SIMD="off"
+    ARM_SIMD="off"
     DEVICE_VERSION=""
     DEVICE=""
     ENABLE_NPU="off"
@@ -331,6 +332,9 @@ checkopts()
             if [[ "$OPTARG" == "sse" || "$OPTARG" == "avx" ]]; then
                 X86_64_SIMD="$OPTARG"
             fi
+            if [[ "$OPTARG" == "neon" ]]; then
+                ARM_SIMD="$OPTARG"
+            fi
             ;;
         H)
             check_on_off $OPTARG H
@@ -474,7 +478,7 @@ build_mindspore()
         CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GPU=ON -DUSE_CUDA=ON -DCUDA_PATH=$CUDA_PATH -DMS_REQUIRE_CUDA_VERSION=${CUDA_VERSION}"
     fi
     if [[ "X$ENABLE_CPU" = "Xon" ]]; then
-        CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON -DX86_64_SIMD=${X86_64_SIMD}"
+        CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON -DX86_64_SIMD=${X86_64_SIMD} -DARM_SIMD=${ARM_SIMD}"
     fi
     if [[ "X$COMPILE_MINDDATA" = "Xon" ]]; then
         CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MINDDATA=ON"
diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
index 704489e53a9..593faa29d57 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
+++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
@@ -61,6 +61,18 @@ if(ENABLE_CPU)
         message("not compiled quantum kernel_compiler")
         set(QUANTUM_SRC_LIST "")
     endif()
+
+    if("${ARM_SIMD}" STREQUAL "neon")
+        set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
+        add_compile_definitions(ENABLE_NEON)
+        set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -ffast-math)
+    endif()
+
+    if("${X86_64_SIMD}" STREQUAL "avx")
+        set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
+        add_compile_definitions(ENABLE_AVX512)
+        set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -fopenmp -mavx512f -ffast-math)
+    endif()
 endif()
 
 if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc
new file mode 100644
index 00000000000..91706fa6f6f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.cc
@@ -0,0 +1,146 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h"
+
+#include <cmath>
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "utils/ms_utils.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2,
+                                                     float epsilon, T *decay, const T *gradient, size_t size) {
+  float beta1_minus = 1 - beta1;
+  float beta2_minus = 1 - beta2;
+#if defined(ENABLE_AVX512)
+  MS_FLOAT32X16 beta1_16 = MS_MOV512_F32(beta1);
+  MS_FLOAT32X16 beta2_16 = MS_MOV512_F32(beta2);
+  MS_FLOAT32X16 beta1_minus_16 = MS_MOV512_F32(beta1_minus);
+  MS_FLOAT32X16 beta2_minus_16 = MS_MOV512_F32(beta2_minus);
+  MS_FLOAT32X16 lr_neg_16 = MS_MOV512_F32(-lr);
+  MS_FLOAT32X16 epsilon_16 = MS_MOV512_F32(epsilon);
+  MS_FLOAT32X16 decay_16 = MS_MOV512_F32(*decay);
+#endif
+#if defined(ENABLE_NEON)
+  MS_FLOAT32X4 epsilon_4 = MS_MOVQ_F32(epsilon);
+  float lr_neg = -lr;
+#endif
+
+  auto task = [&](size_t start, size_t end) {
+    size_t i = start;
+#if defined(ENABLE_AVX512)
+    if (end >= MS_AVX512_WIDTH) {
+      for (; i <= end - MS_AVX512_WIDTH; i += MS_AVX512_WIDTH) {
+        MS_FLOAT32X16 var_16 = MS_LD512_F32(var + i);
+        MS_FLOAT32X16 m_16 = MS_LD512_F32(m + i);
+        MS_FLOAT32X16 v_16 = MS_LD512_F32(v + i);
+        MS_FLOAT32X16 g_16 = MS_LD512_F32(gradient + i);
+        // m = beta1 * m + (1 - beta1) * g
+        m_16 = MS_MUL512_F32(m_16, beta1_16);
+        m_16 = MS_FMA512_F32(g_16, beta1_minus_16, m_16);
+        // v = beta2 * v + (1 - beta2) * g * g
+        v_16 = MS_MUL512_F32(v_16, beta2_16);
+        g_16 = MS_MUL512_F32(g_16, g_16);
+        v_16 = MS_FMA512_F32(g_16, beta2_minus_16, v_16);
+        // var -= lr * (m / (sqrt(v) + epsilon) + decay * var)
+        g_16 = MS_SQRT512_F32(v_16);
+        g_16 = MS_DIV512_F32(m_16, MS_ADD512_F32(g_16, epsilon_16));
+        g_16 = MS_FMA512_F32(var_16, decay_16, g_16);
+        var_16 = MS_FMA512_F32(g_16, lr_neg_16, var_16);
+        MS_ST512_F32(var + i, var_16);
+        MS_ST512_F32(m + i, m_16);
+        MS_ST512_F32(v + i, v_16);
+      }
+    }
+#endif
+#if defined(ENABLE_NEON)
+    if (end >= MS_NEON_WIDTH) {
+      for (; i <= end - MS_NEON_WIDTH; i += MS_NEON_WIDTH) {
+        MS_FLOAT32X4 var_4 = MS_LDQ_F32(var + i);
+        MS_FLOAT32X4 m_4 = MS_LDQ_F32(m + i);
+        MS_FLOAT32X4 v_4 = MS_LDQ_F32(v + i);
+        MS_FLOAT32X4 g_4 = MS_LDQ_F32(gradient + i);
+        // m = beta1 * m + (1 - beta1) * g
+        m_4 = MS_MULQ_N_F32(m_4, beta1);
+        m_4 = MS_MLAQ_N_F32(m_4, g_4, beta1_minus);
+        // v = beta2 * v + (1 - beta2) * g * g
+        v_4 = MS_MULQ_N_F32(v_4, beta2);
+        g_4 = MS_MULQ_F32(g_4, g_4);
+        v_4 = MS_MLAQ_N_F32(v_4, g_4, beta2_minus);
+        // var -= lr * (m / (sqrt(v) + epsilon) + decay * var)
+        g_4 = MS_SQRT_F32(v_4);
+        g_4 = MS_DIVQ_F32(m_4, MS_ADDQ_F32(g_4, epsilon_4));
+        g_4 = MS_MLAQ_N_F32(g_4, var_4, *decay);
+        var_4 = MS_MLAQ_N_F32(var_4, g_4, lr_neg);
+        MS_STQ_F32(var + i, var_4);
+        MS_STQ_F32(m + i, m_4);
+        MS_STQ_F32(v + i, v_4);
+      }
+    }
+#endif
+    // scalar tail for the remaining elements
+    for (; i < end; i++) {
+      m[i] += (gradient[i] - m[i]) * beta1_minus;
+      v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
+      T update = m[i] / (std::sqrt(v[i]) + epsilon);
+      update += decay[0] * var[i];
+      var[i] -= lr * update;
+    }
+  };
+  CPUKernelUtils::ParallelFor(task, size);
+}
+
+void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 9) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
+  }
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 3) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AdamWeightDecay needs 3 outputs.";
+  }
+}
+
+bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                      const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.size() != 9) {
+    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
+  }
+  if (outputs.size() != 3) {
+    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
+  }
+  if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[2]->size || inputs[0]->size != inputs[8]->size) {
+    MS_LOG(EXCEPTION) << "Error input data size!";
+  }
+  size_t f_size = sizeof(float);
+  if (inputs[3]->size != f_size || inputs[4]->size != f_size || inputs[5]->size != f_size ||
+      inputs[6]->size != f_size || inputs[7]->size != f_size) {
+    MS_LOG(EXCEPTION) << "The attributes beta, lr and epsilon must be float!";
+  }
+  auto var = reinterpret_cast<float *>(inputs[0]->addr);
+  auto m = reinterpret_cast<float *>(inputs[1]->addr);
+  auto v = reinterpret_cast<float *>(inputs[2]->addr);
+  float lr = reinterpret_cast<float *>(inputs[3]->addr)[0];
+  float beta1 = reinterpret_cast<float *>(inputs[4]->addr)[0];
+  float beta2 = reinterpret_cast<float *>(inputs[5]->addr)[0];
+  float epsilon = reinterpret_cast<float *>(inputs[6]->addr)[0];
+  auto decay = reinterpret_cast<float *>(inputs[7]->addr);
+  auto gradient = reinterpret_cast<float *>(inputs[8]->addr);
+
+  // multithreading
+  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
+  LaunchAdamWeightDecay<float>(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, lens);
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h
new file mode 100644
index 00000000000..7b04c2b2313
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
+
+#include <memory>
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+#ifdef ENABLE_NEON
+#include <arm_neon.h>
+#endif
+
+#if defined(ENABLE_AVX512)
+#include <immintrin.h>
+#endif
+
+#ifdef ENABLE_NEON
+#define MS_FLOAT32X4 float32x4_t
+#define MS_LDQ_F32 vld1q_f32
+#define MS_MOVQ_F32 vmovq_n_f32
+#define MS_STQ_F32 vst1q_f32
+#define MS_ADDQ_F32(src1, src2) vaddq_f32(src1, src2)
+#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
+#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
+#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
+#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
+#define MS_MLAQ_N_F32(src1, src2, src3) vmlaq_n_f32(src1, src2, src3)
+#define MS_SQRT_F32(src) vsqrtq_f32(src)
+#define MS_CAST_F32_F16(src) vreinterpretq_f32_f16(src)
+#define MS_NEON_WIDTH 4
+#endif
+
+#if defined(ENABLE_AVX512)
+#define MS_FLOAT32X16 __m512
+#define MS_LD512_F32 _mm512_loadu_ps
+#define MS_ST512_F32 _mm512_storeu_ps
+#define MS_MOV512_F32 _mm512_set1_ps
+#define MS_ADD512_F32(src1, src2) _mm512_add_ps(src1, src2)
+#define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2)
+#define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2)
+#define MS_FMA512_F32(src1, src2, src3) _mm512_fmadd_ps(src1, src2, src3)
+#define MS_SQRT512_F32(src) _mm512_sqrt_ps(src)
+#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src)
+#define MS_AVX512_WIDTH 16
+#endif
+
+namespace mindspore {
+namespace kernel {
+class AdamWeightDecayCPUKernel : public CPUKernel {
+ public:
+  AdamWeightDecayCPUKernel() = default;
+  ~AdamWeightDecayCPUKernel() override = default;
+  template <typename T>
+  void LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, T *decay,
+                             const T *gradient, size_t size);
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+};
+
+MS_REG_CPU_KERNEL(AdamWeightDecay,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  AdamWeightDecayCPUKernel)
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py
index 59fe8fb0179..357f799cf1f 100644
--- a/mindspore/ops/operations/inner_ops.py
+++ b/mindspore/ops/operations/inner_ops.py
@@ -495,7 +495,7 @@ class AdamWeightDecay(PrimitiveWithInfer):
         - **v** (Tensor) - The same shape and data type as `v`.
 
     Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``
 
     Examples:
         >>> import numpy as np
diff --git a/model_zoo/official/nlp/gpt/src/adam.py b/model_zoo/official/nlp/gpt/src/adam.py
new file mode 100644
index 00000000000..3403ea95ec6
--- /dev/null
+++ b/model_zoo/official/nlp/gpt/src/adam.py
@@ -0,0 +1,162 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""AdamWeightDecay, a customized Adam optimizer with weight decay for pangu1. Input: gradient."""
+import numpy as np
+
+from mindspore.common import dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.common.tensor import Tensor
+from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
+from mindspore.nn.optim.optimizer import Optimizer
+
+_adam_opt = C.MultitypeFuncGraph("adam_opt")
+_scaler_one = Tensor(1, mstype.int32)
+_scaler_ten = Tensor(10, mstype.float32)
+
+
+@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Bool", "Bool")
+def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
+    """
+    Update parameters by the AdamWeightDecay op.
+    """
+    if optim_filter:
+        op_cast = P.Cast()
+        gradient_fp32 = op_cast(gradient, mstype.float32)
+        if decay_flags:
+            next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(weight_decay, mstype.float32), gradient_fp32)
+        else:
+            next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(0.0, mstype.float32), gradient_fp32)
+        return next_param
+    return gradient
+
+
+def _check_param_value(beta1, beta2, eps, prim_name):
+    """Check the types and ranges of the inputs."""
+    validator.check_value_type("beta1", beta1, [float], prim_name)
+    validator.check_value_type("beta2", beta2, [float], prim_name)
+    validator.check_value_type("eps", eps, [float], prim_name)
+    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
+    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
+    validator.check_positive_float(eps, "eps", prim_name)
+
+
+class AdamWeightDecayOp(Optimizer):
+    """
+    Implements the Adam algorithm with weight decay (AdamWeightDecay). It is implemented as a single fused
+    operator rather than a combination of smaller ops.
+
+    Note:
+        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
+        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
+        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
+
+        To improve the performance of parameter groups, a customized order of parameters is supported.
+
+    Args:
+        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
+            the element in `params` must be class `Parameter`.
+            When the `params` is a list of `dict`, the "params", "lr", "weight_decay" and "order_params"
+            are the keys that can be parsed.
+
+            - params: Required. The value must be a list of `Parameter`.
+
+            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
+              If not, the `learning_rate` in the API will be used.
+
+            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
+              will be used. If not, the `weight_decay` in the API will be used.
+
+            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters,
+              and this order will be followed by the optimizer. No other keys may appear in this `dict`, and the
+              parameters listed in 'order_params' must be members of one of the parameter groups.
+
+        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
+            rate. When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used, and the
+            i-th step takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule,
+            a dynamic learning rate is also used, and the i-th learning rate is computed during training according
+            to the formula of the LearningRateSchedule. When the learning_rate is a float or a zero-dimensional
+            Tensor, a fixed learning rate is used. Other cases are not supported. The float learning rate must be
+            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
+            Default: 1e-3.
+        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
+            Should be in range (0.0, 1.0).
+        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
+            Should be in range (0.0, 1.0).
+        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
+            Should be greater than 0.
+        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
+
+    Inputs:
+        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
+
+    Outputs:
+        tuple[bool], all elements are True.
+
+    Supported Platforms:
+        ``CPU``
+
+    Examples:
+        >>> net = Net()
+        >>> #1) All parameters use the same learning rate and weight decay
+        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
+        >>>
+        >>> #2) Use parameter groups and set different values
+        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+        ...                 {'params': no_conv_params, 'lr': 0.01},
+        ...                 {'order_params': net.trainable_params()}]
+        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
+        >>> # The parameters in conv_params will use the default learning rate of 0.1 and a weight decay of 0.01.
+        >>> # The parameters in no_conv_params will use a learning rate of 0.01 and the default weight decay of 0.0.
+        >>> # The order in which the optimizer updates the parameters follows the value of 'order_params'.
+ >>> + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> model = Model(net, loss_fn=loss, optimizer=optim) + """ + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): + super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay) + _check_param_value(beta1, beta2, eps, self.cls_name) + self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) + self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) + self.eps = Tensor(np.array([eps]).astype(np.float32)) + self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') + self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') + self.hyper_map = C.HyperMap() + self.opt = P.AdamWeightDecay() + self.opt.add_prim_attr("primitive_target", "CPU") + + def construct(self, gradients): + """AdamWeightDecayOp""" + lr = self.get_lr() + if self.is_group: + if self.is_group_lr: + optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps), + lr, self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr), + self.weight_decay, self.parameters, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr, + self.weight_decay), + self.parameters, self.moments1, self.moments2, gradients, self.decay_flags, + self.optim_filter) + if self.use_parallel: + self.broadcast_params(optim_result) + return optim_result diff --git a/tests/st/ops/cpu/test_adam_weight_decay_op.py b/tests/st/ops/cpu/test_adam_weight_decay_op.py new file mode 100644 index 00000000000..2da5f6adbe2 --- /dev/null +++ b/tests/st/ops/cpu/test_adam_weight_decay_op.py @@ -0,0 +1,66 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import Dense +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.ops import operations as P +from model_zoo.official.nlp.gpt.src.adam import AdamWeightDecayOp + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetAdamWeightDecay(nn.Cell): + def __init__(self): + super(NetAdamWeightDecay, self).__init__() + self.batch_size = 1 + self.reshape = P.Reshape() + weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) + self.fc1 = Dense(16, 10, weight_init=weight) + + def construct(self, input_x): + output = self.reshape(input_x, (self.batch_size, -1)) + output = self.fc1(output) + return output + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_adam_weight_decay(): + epoch = 3 + net = NetAdamWeightDecay() + optimizer = AdamWeightDecayOp(filter(lambda x: x.requires_grad, + net.get_parameters()), learning_rate=0.01) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell( + net_with_criterion, optimizer) + train_network.set_train() + + losses1 = [] + for _ in range(epoch): + data = Tensor(np.arange(0, 16).reshape( + 1, 1, 4, 4).astype(np.float32) * 0.01) + label = Tensor(np.array([0]).astype(np.int32)) + loss = train_network(data, label) + losses1.append(loss.asnumpy()) + assert losses1[0] > losses1[1] + assert losses1[1] > losses1[2]