add AdamWeightDecay CPU op and optimize with SIMD intrinsics
parent 81cd26bdc8
commit 0a920057a5
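For reference, the per-element update implemented by the new kernel (no bias correction is applied, matching the scalar fallback loop in the C++ source below) is:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
    var_t = var_{t-1} - lr \left( \frac{m_t}{\sqrt{v_t} + \epsilon} + decay \cdot var_{t-1} \right)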
build.sh (6 changed lines)
@@ -119,6 +119,7 @@ checkopts()
   ANDROID_STL="c++_shared"
   ENABLE_MAKE_CLEAN="off"
   X86_64_SIMD="off"
+  ARM_SIMD="off"
   DEVICE_VERSION=""
   DEVICE=""
   ENABLE_NPU="off"
@@ -331,6 +332,9 @@ checkopts()
       if [[ "$OPTARG" == "sse" || "$OPTARG" == "avx" ]]; then
         X86_64_SIMD="$OPTARG"
       fi
+      if [[ "$OPTARG" == "neon" ]]; then
+        ARM_SIMD="$OPTARG"
+      fi
       ;;
     H)
       check_on_off $OPTARG H
@@ -474,7 +478,7 @@ build_mindspore()
     CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GPU=ON -DUSE_CUDA=ON -DCUDA_PATH=$CUDA_PATH -DMS_REQUIRE_CUDA_VERSION=${CUDA_VERSION}"
   fi
   if [[ "X$ENABLE_CPU" = "Xon" ]]; then
-    CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON -DX86_64_SIMD=${X86_64_SIMD}"
+    CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON -DX86_64_SIMD=${X86_64_SIMD} -DARM_SIMD=${ARM_SIMD}"
   fi
   if [[ "X$COMPILE_MINDDATA" = "Xon" ]]; then
     CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MINDDATA=ON"
@@ -61,6 +61,18 @@ if(ENABLE_CPU)
         message("not compiled quantum kernel_compiler")
         set(QUANTUM_SRC_LIST "")
     endif()
+
+    if("${ARM_SIMD}" STREQUAL "neon")
+        set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
+        add_compile_definitions(ENABLE_NEON)
+        set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -ffast-math)
+    endif()
+
+    if("${X86_64_SIMD}" STREQUAL "avx")
+        set(CPU_SIMD_SRC "${CMAKE_CURRENT_SOURCE_DIR}/cpu/adam_weight_decay_cpu_kernel.cc")
+        add_compile_definitions(ENABLE_AVX512)
+        set_property(SOURCE ${CPU_SIMD_SRC} PROPERTY COMPILE_OPTIONS -O3 -fopenmp -mavx512f -ffast-math)
+    endif()
 endif()

 if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
@@ -0,0 +1,146 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/cpu/adam_weight_decay_cpu_kernel.h"

#include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace kernel {
template <typename T>
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2,
                                                     float epsilon, T *decay, const T *gradient, size_t size) {
  float beta1_minus = 1 - beta1;
  float beta2_minus = 1 - beta2;
#if defined(ENABLE_AVX512)
  MS_FLOAT32X16 beta1_16 = MS_MOV512_F32(beta1);
  MS_FLOAT32X16 beta2_16 = MS_MOV512_F32(beta2);
  MS_FLOAT32X16 beta1_minus_16 = MS_MOV512_F32(beta1_minus);
  MS_FLOAT32X16 beta2_minus_16 = MS_MOV512_F32(beta2_minus);
  MS_FLOAT32X16 lr_neg_16 = MS_MOV512_F32(-lr);
  MS_FLOAT32X16 epsilon_16 = MS_MOV512_F32(epsilon);
  MS_FLOAT32X16 decay_16 = MS_MOV512_F32(*decay);
#endif
#if defined(ENABLE_NEON)
  MS_FLOAT32X4 epsilon_4 = MS_MOVQ_F32(epsilon);
  float lr_neg = -lr;
#endif

  auto task = [&](size_t start, size_t end) {
    size_t i = start;
#if defined(ENABLE_AVX512)
    if (end >= MS_AVX512_WIDTH) {
      for (; i <= end - MS_AVX512_WIDTH; i += MS_AVX512_WIDTH) {
        MS_FLOAT32X16 var_16 = MS_LD512_F32(var + i);
        MS_FLOAT32X16 m_16 = MS_LD512_F32(m + i);
        MS_FLOAT32X16 v_16 = MS_LD512_F32(v + i);
        MS_FLOAT32X16 g_16 = MS_LD512_F32(gradient + i);
        m_16 = MS_MUL512_F32(m_16, beta1_16);
        m_16 = MS_FMA512_F32(g_16, beta1_minus_16, m_16);
        v_16 = MS_MUL512_F32(v_16, beta2_16);
        g_16 = MS_MUL512_F32(g_16, g_16);  // square the gradient in a temp; v_16 must keep v * beta2
        v_16 = MS_FMA512_F32(g_16, beta2_minus_16, v_16);
        g_16 = MS_SQRT512_F32(v_16);
        g_16 = MS_DIV512_F32(m_16, MS_ADD512_F32(g_16, epsilon_16));
        g_16 = MS_FMA512_F32(var_16, decay_16, g_16);
        var_16 = MS_FMA512_F32(g_16, lr_neg_16, var_16);
        MS_ST512_F32(var + i, var_16);
        MS_ST512_F32(m + i, m_16);
        MS_ST512_F32(v + i, v_16);
      }
    }
#endif
#if defined(ENABLE_NEON)
    if (end >= MS_NEON_WIDTH) {
      for (; i <= end - MS_NEON_WIDTH; i += MS_NEON_WIDTH) {
        MS_FLOAT32X4 var_4 = MS_LDQ_F32(var + i);
        MS_FLOAT32X4 m_4 = MS_LDQ_F32(m + i);
        MS_FLOAT32X4 v_4 = MS_LDQ_F32(v + i);
        MS_FLOAT32X4 g_4 = MS_LDQ_F32(gradient + i);
        m_4 = MS_MULQ_N_F32(m_4, beta1);
        m_4 = MS_MLAQ_N_F32(m_4, g_4, beta1_minus);
        v_4 = MS_MULQ_N_F32(v_4, beta2);
        g_4 = MS_MULQ_F32(g_4, g_4);
        v_4 = MS_MLAQ_N_F32(v_4, g_4, beta2_minus);
        g_4 = MS_SQRT_F32(v_4);
        g_4 = MS_DIVQ_F32(m_4, MS_ADDQ_F32(g_4, epsilon_4));
        g_4 = MS_MLAQ_N_F32(g_4, var_4, *decay);
        var_4 = MS_MLAQ_N_F32(var_4, g_4, lr_neg);
        MS_STQ_F32(var + i, var_4);
        MS_STQ_F32(m + i, m_4);
        MS_STQ_F32(v + i, v_4);
      }
    }
#endif
    for (; i < end; i++) {
      m[i] += (gradient[i] - m[i]) * beta1_minus;
      v[i] += (gradient[i] * gradient[i] - v[i]) * beta2_minus;
      T update = m[i] / (std::sqrt(v[i]) + epsilon);
      update += decay[0] * var[i];
      var[i] -= lr * update;
    }
  };
  CPUKernelUtils::ParallelFor(task, size);
}

void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 9) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 3) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AdamWeightDecay needs 3 outputs.";
  }
}

bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> & /*workspace*/,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() != 9) {
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
  }
  if (outputs.size() != 3) {
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
  }
  if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[2]->size || inputs[0]->size != inputs[8]->size) {
    MS_LOG(EXCEPTION) << "Error input data size!";
  }
  size_t f_size = sizeof(float);
  if (inputs[3]->size != f_size || inputs[4]->size != f_size || inputs[5]->size != f_size ||
      inputs[6]->size != f_size || inputs[7]->size != f_size) {
    MS_LOG(EXCEPTION) << "The inputs lr, beta1, beta2, epsilon and decay must each be a single float!";
  }
  auto var = reinterpret_cast<float *>(inputs[0]->addr);
  auto m = reinterpret_cast<float *>(inputs[1]->addr);
  auto v = reinterpret_cast<float *>(inputs[2]->addr);
  float lr = reinterpret_cast<float *>(inputs[3]->addr)[0];
  float beta1 = reinterpret_cast<float *>(inputs[4]->addr)[0];
  float beta2 = reinterpret_cast<float *>(inputs[5]->addr)[0];
  float epsilon = reinterpret_cast<float *>(inputs[6]->addr)[0];
  auto decay = reinterpret_cast<float *>(inputs[7]->addr);
  auto gradient = reinterpret_cast<float *>(inputs[8]->addr);

  // multithreading
  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
  LaunchAdamWeightDecay<float>(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, lens);
  return true;
}
}  // namespace kernel
}  // namespace mindspore
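The scalar tail loop above is the semantic reference that both the AVX-512 and NEON paths must match. A minimal NumPy sketch of the same per-element update, useful for eyeballing the kernel's output (the function name and arguments are illustrative, not part of the commit):

import numpy as np

def adam_weight_decay_ref(var, m, v, lr, beta1, beta2, eps, decay, grad):
    """Mirror of the scalar fallback loop in adam_weight_decay_cpu_kernel.cc."""
    # var, m, v, grad: float32 arrays of the same shape; lr, beta1, beta2, eps, decay: floats.
    m += (grad - m) * (1.0 - beta1)          # m = beta1*m + (1-beta1)*g
    v += (grad * grad - v) * (1.0 - beta2)   # v = beta2*v + (1-beta2)*g^2
    update = m / (np.sqrt(v) + eps) + decay * var
    var -= lr * update
    return var, m, v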
@@ -0,0 +1,97 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_

#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif

#if defined(ENABLE_AVX512)
#include <x86intrin.h>
#endif

#ifdef ENABLE_NEON
#define MS_FLOAT32X4 float32x4_t
#define MS_LDQ_F32 vld1q_f32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_STQ_F32 vst1q_f32
#define MS_ADDQ_F32(src1, src2) vaddq_f32(src1, src2)
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
#define MS_MLAQ_N_F32(src1, src2, src3) vmlaq_n_f32(src1, src2, src3)
#define MS_SQRT_F32(src) vsqrtq_f32(src)
#define MS_CAST_F32_F16(src) vreinterpretq_f32_f16(src)
#define MS_NEON_WIDTH 4
#endif

#if defined(ENABLE_AVX512)
#define MS_FLOAT32X16 __m512
#define MS_LD512_F32 _mm512_loadu_ps
#define MS_ST512_F32 _mm512_storeu_ps
#define MS_MOV512_F32 _mm512_set1_ps
#define MS_ADD512_F32(src1, src2) _mm512_add_ps(src1, src2)
#define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2)
#define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2)
#define MS_FMA512_F32(src1, src2, src3) _mm512_fmadd_ps(src1, src2, src3)
#define MS_SQRT512_F32(src) _mm512_sqrt_ps(src)
#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src)
#define MS_AVX512_WIDTH 16
#endif

namespace mindspore {
namespace kernel {
class AdamWeightDecayCPUKernel : public CPUKernel {
 public:
  AdamWeightDecayCPUKernel() = default;
  ~AdamWeightDecayCPUKernel() override = default;
  template <typename T>
  void LaunchAdamWeightDecay(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, T *decay,
                             const T *gradient, size_t size);
  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
};

MS_REG_CPU_KERNEL(AdamWeightDecay,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  AdamWeightDecayCPUKernel)
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADAM_WEIGHT_DECAY_CPU_KERNEL_H_
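The registration above fixes the kernel's contract: nine float32 inputs in the order (var, m, v, lr, beta1, beta2, epsilon, decay, gradient) and three float32 outputs. A rough, untested sketch of driving the primitive on CPU through a Cell, following the call order used by _update_run_kernel in adam.py below (the Cell name, shapes, and scalar values are illustrative assumptions, not from the commit):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

class AdamWeightDecayNet(nn.Cell):
    def __init__(self, shape):
        super(AdamWeightDecayNet, self).__init__()
        self.opt = P.AdamWeightDecay()
        # var, m, v are updated in place, so they must be Parameters.
        self.var = Parameter(Tensor(np.ones(shape, np.float32)), name="var")
        self.m = Parameter(Tensor(np.zeros(shape, np.float32)), name="m")
        self.v = Parameter(Tensor(np.zeros(shape, np.float32)), name="v")

    def construct(self, lr, beta1, beta2, eps, decay, grad):
        return self.opt(self.var, self.m, self.v, lr, beta1, beta2, eps, decay, grad)

net = AdamWeightDecayNet((2, 2))
# The CPU kernel expects each scalar input to be a single float32 value.
scalar = lambda x: Tensor(np.array([x], np.float32))
grad = Tensor(np.full((2, 2), 0.1, np.float32))
net(scalar(1e-3), scalar(0.9), scalar(0.999), scalar(1e-6), scalar(1e-2), grad)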
@@ -495,7 +495,7 @@ class AdamWeightDecay(PrimitiveWithInfer):
         - **v** (Tensor) - The same shape and data type as `v`.

     Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

     Examples:
         >>> import numpy as np
@@ -0,0 +1,162 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecay, a customized Adam for pangu1. Input: gradient."""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)


@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """
    Update parameters by the AdamWeightDecay op.
    """
    if optim_filter:
        op_cast = P.Cast()
        gradient_fp32 = op_cast(gradient, mstype.float32)
        if decay_flags:
            next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(weight_decay, mstype.float32), gradient_fp32)
        else:
            next_param = opt(param, m, v, lr, beta1, beta2, eps, F.cast(0.0, mstype.float32), gradient_fp32)
        return next_param
    return gradient


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class AdamWeightDecayOp(Optimizer):
    """
    Implements the Adam algorithm with decoupled weight decay (AdamW). It is a single fused operator,
    not a combination of other ops.

    Note:
        When separating parameter groups, the weight decay in each group will be applied to the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        to the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve performance when using parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
              listed in 'order_params' must be in one of the group parameter lists.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, the dynamic learning rate is used and the i-th step
            takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, the dynamic
            learning rate is used and the i-th learning rate is calculated during training according to the formula of
            the LearningRateSchedule. When the learning_rate is a float or a 0-D Tensor, a fixed learning rate is used.
            Other cases are not supported. The float learning rate must be equal to or greater than 0. If the type of
            `learning_rate` is int, it will be converted to float.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and the weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The order of parameters the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.AdamWeightDecay()
        self.opt.add_prim_attr("primitive_target", "CPU")

    def construct(self, gradients):
        """AdamWeightDecayOp"""
        lr = self.get_lr()
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),
                                         lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                         gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),
                                         self.weight_decay, self.parameters, self.moments1, self.moments2,
                                         gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.map_(F.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,
                                               self.weight_decay),
                                     self.parameters, self.moments1, self.moments2, gradients, self.decay_flags,
                                     self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
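As the learning_rate description in the docstring above notes, an Iterable or 1-D Tensor selects a per-step (dynamic) learning rate while a float or 0-D Tensor is fixed. A brief illustrative sketch; the schedule values are made up and `net` stands for any nn.Cell with trainable parameters, as in the Examples:

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

# Fixed learning rate: a plain float.
optim_fixed = AdamWeightDecayOp(net.trainable_params(), learning_rate=1e-3, weight_decay=0.01)

# Dynamic learning rate: one value per training step, passed as a 1-D Tensor.
lr_schedule = Tensor(np.linspace(1e-3, 1e-5, num=1000), mstype.float32)
optim_dynamic = AdamWeightDecayOp(net.trainable_params(), learning_rate=lr_schedule, weight_decay=0.01)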
@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from model_zoo.official.nlp.gpt.src.adam import AdamWeightDecayOp

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetAdamWeightDecay(nn.Cell):
    def __init__(self):
        super(NetAdamWeightDecay, self).__init__()
        self.batch_size = 1
        self.reshape = P.Reshape()
        weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
        self.fc1 = Dense(16, 10, weight_init=weight)

    def construct(self, input_x):
        output = self.reshape(input_x, (self.batch_size, -1))
        output = self.fc1(output)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adam_weight_decay():
    epoch = 3
    net = NetAdamWeightDecay()
    optimizer = AdamWeightDecayOp(filter(lambda x: x.requires_grad,
                                         net.get_parameters()), learning_rate=0.01)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(
        net_with_criterion, optimizer)
    train_network.set_train()

    losses1 = []
    for _ in range(epoch):
        data = Tensor(np.arange(0, 16).reshape(
            1, 1, 4, 4).astype(np.float32) * 0.01)
        label = Tensor(np.array([0]).astype(np.int32))
        loss = train_network(data, label)
        losses1.append(loss.asnumpy())
    assert losses1[0] > losses1[1]
    assert losses1[1] > losses1[2]