forked from mindspore-Ecosystem/mindspore
update control flow int adamweightdecay for bert
This commit is contained in:
parent
da7ce4a2e9
commit
0c97835662
2
akg
2
akg
|
@ -1 +1 @@
|
||||||
Subproject commit ae997e27b217d6c8c7a6cbf6ef812186835d2bdf
|
Subproject commit f4f118a2debd2eacc3f2ab6dc31846f1e04d6e13
|
|
@ -88,7 +88,6 @@ __global__ void IsFinite(const size_t size, const half* input, bool* out) {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void FloatStatus(const size_t size, const T* input, T* out) {
|
__global__ void FloatStatus(const size_t size, const T* input, T* out) {
|
||||||
out[0] = 0;
|
|
||||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||||
if (isinf(input[pos]) != 0 || isnan(input[pos])) {
|
if (isinf(input[pos]) != 0 || isnan(input[pos])) {
|
||||||
out[0] = 1;
|
out[0] = 1;
|
||||||
|
@ -98,7 +97,6 @@ __global__ void FloatStatus(const size_t size, const T* input, T* out) {
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
__global__ void FloatStatus(const size_t size, const half* input, half* out) {
|
__global__ void FloatStatus(const size_t size, const half* input, half* out) {
|
||||||
out[0] = 0;
|
|
||||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||||
if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) {
|
if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) {
|
||||||
out[0] = 1;
|
out[0] = 1;
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||||
#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh"
|
#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh"
|
||||||
|
#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
@ -46,6 +47,7 @@ class FloatStatusGpuKernel : public GpuKernel {
|
||||||
switch (kernel_name_) {
|
switch (kernel_name_) {
|
||||||
case OP_STATUS: {
|
case OP_STATUS: {
|
||||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||||
|
FillDeviceArray(outputs[0]->size / sizeof(T), output, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,8 @@ from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore.common import set_seed
|
from mindspore.common import set_seed
|
||||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||||
BertTrainAccumulateStepsWithLossScaleCell
|
BertTrainAccumulateStepsWithLossScaleCell, BertTrainOneStepWithLossScaleCellForAdam, \
|
||||||
|
AdamWeightDecayForBert
|
||||||
from src.dataset import create_bert_dataset
|
from src.dataset import create_bert_dataset
|
||||||
from src.config import cfg, bert_net_cfg
|
from src.config import cfg, bert_net_cfg
|
||||||
from src.utils import LossCallBack, BertLearningRate
|
from src.utils import LossCallBack, BertLearningRate
|
||||||
|
@ -83,8 +84,10 @@ def _get_optimizer(args_opt, network):
|
||||||
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
|
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
|
||||||
{'params': other_params, 'weight_decay': 0.0},
|
{'params': other_params, 'weight_decay': 0.0},
|
||||||
{'order_params': params}]
|
{'order_params': params}]
|
||||||
|
if args_opt.enable_lossscale == "true":
|
||||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||||
|
else:
|
||||||
|
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
|
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
|
||||||
format(cfg.optimizer))
|
format(cfg.optimizer))
|
||||||
|
@ -206,8 +209,12 @@ def run_pretrain():
|
||||||
scale_window=cfg.scale_window)
|
scale_window=cfg.scale_window)
|
||||||
|
|
||||||
if args_opt.accumulation_steps <= 1:
|
if args_opt.accumulation_steps <= 1:
|
||||||
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
if cfg.optimizer == 'AdamWeightDecay':
|
||||||
scale_update_cell=update_cell)
|
net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
|
||||||
|
scale_update_cell=update_cell)
|
||||||
|
else:
|
||||||
|
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
||||||
|
scale_update_cell=update_cell)
|
||||||
else:
|
else:
|
||||||
accumulation_steps = args_opt.accumulation_steps
|
accumulation_steps = args_opt.accumulation_steps
|
||||||
net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
||||||
|
|
|
@ -21,13 +21,13 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
|
||||||
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
|
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
|
||||||
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
|
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
|
||||||
SaturateCast, CreateAttentionMaskFromInputMask
|
SaturateCast, CreateAttentionMaskFromInputMask
|
||||||
|
from .adam import AdamWeightDecayForBert
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
|
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
|
||||||
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
|
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
|
||||||
"BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
|
"BertTrainOneStepWithLossScaleCell", "BertTrainAccumulateStepsWithLossScaleCell",
|
||||||
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
|
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
|
||||||
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
|
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
|
||||||
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
|
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",
|
||||||
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
|
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,307 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore._checkparam import Validator as validator
|
||||||
|
from mindspore._checkparam import Rel
|
||||||
|
from mindspore.nn.optim.optimizer import Optimizer
|
||||||
|
|
||||||
|
_adam_opt = C.MultitypeFuncGraph("adam_opt")
|
||||||
|
_scaler_one = Tensor(1, mstype.int32)
|
||||||
|
_scaler_ten = Tensor(10, mstype.float32)
|
||||||
|
|
||||||
|
|
||||||
|
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||||
|
"Tensor", "Bool", "Bool")
|
||||||
|
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
|
||||||
|
"""
|
||||||
|
Update parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
|
||||||
|
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
|
||||||
|
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||||
|
lr (Tensor): Learning rate.
|
||||||
|
overflow (Tensor): Whether overflow occurs.
|
||||||
|
weight_decay (Number): Weight decay. Should be equal to or greater than 0.
|
||||||
|
param (Tensor): Parameters.
|
||||||
|
m (Tensor): m value of parameters.
|
||||||
|
v (Tensor): v value of parameters.
|
||||||
|
gradient (Tensor): Gradient of parameters.
|
||||||
|
decay_flag (bool): Applies weight decay or not.
|
||||||
|
optim_filter (bool): Applies parameter update or not.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, the new value of v after updating.
|
||||||
|
"""
|
||||||
|
if optim_filter:
|
||||||
|
op_mul = P.Mul()
|
||||||
|
op_square = P.Square()
|
||||||
|
op_sqrt = P.Sqrt()
|
||||||
|
op_cast = P.Cast()
|
||||||
|
op_reshape = P.Reshape()
|
||||||
|
op_shape = P.Shape()
|
||||||
|
op_select = P.Select()
|
||||||
|
|
||||||
|
param_fp32 = op_cast(param, mstype.float32)
|
||||||
|
m_fp32 = op_cast(m, mstype.float32)
|
||||||
|
v_fp32 = op_cast(v, mstype.float32)
|
||||||
|
gradient_fp32 = op_cast(gradient, mstype.float32)
|
||||||
|
|
||||||
|
cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
|
||||||
|
next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\
|
||||||
|
op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))
|
||||||
|
|
||||||
|
next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\
|
||||||
|
op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))
|
||||||
|
|
||||||
|
update = next_m / (eps + op_sqrt(next_v))
|
||||||
|
if decay_flag:
|
||||||
|
update = op_mul(weight_decay, param_fp32) + update
|
||||||
|
|
||||||
|
update_with_lr = op_mul(lr, update)
|
||||||
|
zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
|
||||||
|
next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))
|
||||||
|
|
||||||
|
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
|
||||||
|
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
|
||||||
|
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
|
||||||
|
|
||||||
|
return op_cast(next_param, F.dtype(param))
|
||||||
|
return gradient
|
||||||
|
|
||||||
|
|
||||||
|
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
||||||
|
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||||
|
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
|
||||||
|
beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
|
||||||
|
"""Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
|
||||||
|
success = True
|
||||||
|
indices = gradient.indices
|
||||||
|
values = gradient.values
|
||||||
|
if ps_parameter and not cache_enable:
|
||||||
|
op_shape = P.Shape()
|
||||||
|
shapes = (op_shape(param), op_shape(m), op_shape(v),
|
||||||
|
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
|
||||||
|
op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
|
||||||
|
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
|
||||||
|
eps, values, indices), shapes), param))
|
||||||
|
return success
|
||||||
|
|
||||||
|
if not target:
|
||||||
|
success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
|
||||||
|
eps, values, indices))
|
||||||
|
else:
|
||||||
|
op_mul = P.Mul()
|
||||||
|
op_square = P.Square()
|
||||||
|
op_sqrt = P.Sqrt()
|
||||||
|
scatter_add = P.ScatterAdd(use_locking)
|
||||||
|
|
||||||
|
assign_m = F.assign(m, op_mul(beta1, m))
|
||||||
|
assign_v = F.assign(v, op_mul(beta2, v))
|
||||||
|
|
||||||
|
grad_indices = gradient.indices
|
||||||
|
grad_value = gradient.values
|
||||||
|
|
||||||
|
next_m = scatter_add(m,
|
||||||
|
grad_indices,
|
||||||
|
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
|
||||||
|
|
||||||
|
next_v = scatter_add(v,
|
||||||
|
grad_indices,
|
||||||
|
op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))
|
||||||
|
|
||||||
|
if use_nesterov:
|
||||||
|
m_temp = next_m * _scaler_ten
|
||||||
|
assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
|
||||||
|
div_value = scatter_add(m,
|
||||||
|
op_mul(grad_indices, _scaler_one),
|
||||||
|
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
|
||||||
|
param_update = div_value / (op_sqrt(next_v) + eps)
|
||||||
|
|
||||||
|
m_recover = F.assign(m, m_temp / _scaler_ten)
|
||||||
|
|
||||||
|
F.control_depend(m_temp, assign_m_nesterov)
|
||||||
|
F.control_depend(assign_m_nesterov, div_value)
|
||||||
|
F.control_depend(param_update, m_recover)
|
||||||
|
else:
|
||||||
|
param_update = next_m / (op_sqrt(next_v) + eps)
|
||||||
|
|
||||||
|
lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
|
||||||
|
|
||||||
|
next_param = param - lr_t * param_update
|
||||||
|
|
||||||
|
F.control_depend(assign_m, next_m)
|
||||||
|
F.control_depend(assign_v, next_v)
|
||||||
|
|
||||||
|
success = F.depend(success, F.assign(param, next_param))
|
||||||
|
success = F.depend(success, F.assign(m, next_m))
|
||||||
|
success = F.depend(success, F.assign(v, next_v))
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
|
||||||
|
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
|
||||||
|
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
|
||||||
|
beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
|
||||||
|
moment1, moment2, ps_parameter, cache_enable):
|
||||||
|
"""Apply adam optimizer to the weight parameter using Tensor."""
|
||||||
|
success = True
|
||||||
|
if ps_parameter and not cache_enable:
|
||||||
|
op_shape = P.Shape()
|
||||||
|
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
|
||||||
|
(op_shape(param), op_shape(moment1), op_shape(moment2))), param))
|
||||||
|
else:
|
||||||
|
success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
|
||||||
|
eps, gradient))
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
||||||
|
"Tensor", "Tensor")
|
||||||
|
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
|
||||||
|
"""Apply AdamOffload optimizer to the weight parameter using Tensor."""
|
||||||
|
success = True
|
||||||
|
delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
|
||||||
|
success = F.depend(success, F.assign_add(param, delat_param))
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
|
def _check_param_value(beta1, beta2, eps, prim_name):
|
||||||
|
"""Check the type of inputs."""
|
||||||
|
validator.check_value_type("beta1", beta1, [float], prim_name)
|
||||||
|
validator.check_value_type("beta2", beta2, [float], prim_name)
|
||||||
|
validator.check_value_type("eps", eps, [float], prim_name)
|
||||||
|
validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
|
||||||
|
validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
|
||||||
|
validator.check_positive_float(eps, "eps", prim_name)
|
||||||
|
|
||||||
|
class AdamWeightDecayForBert(Optimizer):
|
||||||
|
"""
|
||||||
|
Implements the Adam algorithm to fix the weight decay.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
|
||||||
|
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
|
||||||
|
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
|
||||||
|
|
||||||
|
To improve parameter groups performance, the customized order of parameters can be supported.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
|
||||||
|
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
|
||||||
|
"lr", "weight_decay" and "order_params" are the keys can be parsed.
|
||||||
|
|
||||||
|
- params: Required. The value must be a list of `Parameter`.
|
||||||
|
|
||||||
|
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
|
||||||
|
If not, the `learning_rate` in the API will be used.
|
||||||
|
|
||||||
|
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
|
||||||
|
will be used. If not, the `weight_decay` in the API will be used.
|
||||||
|
|
||||||
|
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
|
||||||
|
the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
|
||||||
|
which in the 'order_params' must be in one of group parameters.
|
||||||
|
|
||||||
|
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
|
||||||
|
When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
|
||||||
|
the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
|
||||||
|
use dynamic learning rate, the i-th learning rate will be calculated during the process of training
|
||||||
|
according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
|
||||||
|
dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
|
||||||
|
equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
|
||||||
|
Default: 1e-3.
|
||||||
|
beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
|
||||||
|
Should be in range (0.0, 1.0).
|
||||||
|
beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
|
||||||
|
Should be in range (0.0, 1.0).
|
||||||
|
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
|
||||||
|
Should be greater than 0.
|
||||||
|
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||||
|
- **overflow** (tuple[Tensor]) - The overflow flag in dynamiclossscale.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
tuple[bool], all elements are True.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = Net()
|
||||||
|
>>> #1) All parameters use the same learning rate and weight decay
|
||||||
|
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
|
||||||
|
>>>
|
||||||
|
>>> #2) Use parameter groups and set different values
|
||||||
|
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
|
||||||
|
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
|
||||||
|
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
|
||||||
|
... {'params': no_conv_params, 'lr': 0.01},
|
||||||
|
... {'order_params': net.trainable_params()}]
|
||||||
|
>>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
|
||||||
|
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
|
||||||
|
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
|
||||||
|
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
|
||||||
|
>>>
|
||||||
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
||||||
|
"""
|
||||||
|
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
|
||||||
|
super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
|
||||||
|
_check_param_value(beta1, beta2, eps, self.cls_name)
|
||||||
|
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
|
||||||
|
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
|
||||||
|
self.eps = Tensor(np.array([eps]).astype(np.float32))
|
||||||
|
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
|
||||||
|
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
self.op_select = P.Select()
|
||||||
|
self.op_cast = P.Cast()
|
||||||
|
self.op_reshape = P.Reshape()
|
||||||
|
self.op_shape = P.Shape()
|
||||||
|
|
||||||
|
def construct(self, gradients, overflow):
|
||||||
|
"""AdamWeightDecayForBert"""
|
||||||
|
lr = self.get_lr()
|
||||||
|
cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\
|
||||||
|
self.op_reshape(overflow, (())), mstype.bool_)
|
||||||
|
beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
|
||||||
|
beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
|
||||||
|
if self.is_group:
|
||||||
|
if self.is_group_lr:
|
||||||
|
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
|
||||||
|
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||||
|
gradients, self.decay_flags, self.optim_filter)
|
||||||
|
else:
|
||||||
|
optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
|
||||||
|
self.weight_decay, self.parameters, self.moments1, self.moments2,
|
||||||
|
gradients, self.decay_flags, self.optim_filter)
|
||||||
|
else:
|
||||||
|
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
|
||||||
|
self.parameters, self.moments1, self.moments2,
|
||||||
|
gradients, self.decay_flags, self.optim_filter)
|
||||||
|
if self.use_parallel:
|
||||||
|
self.broadcast_params(optim_result)
|
||||||
|
return optim_result
|
|
@ -440,6 +440,120 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
ret = (loss, cond, scaling_sens)
|
ret = (loss, cond, scaling_sens)
|
||||||
return F.depend(ret, succ)
|
return F.depend(ret, succ)
|
||||||
|
|
||||||
|
class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell):
|
||||||
|
"""
|
||||||
|
Encapsulation class of bert network training.
|
||||||
|
|
||||||
|
Append an optimizer to the training network after that the construct
|
||||||
|
function can be called to create the backward graph.
|
||||||
|
Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow
|
||||||
|
condition as input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The training network. Note that loss function should have been added.
|
||||||
|
optimizer (Optimizer): Optimizer for updating the weights.
|
||||||
|
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||||
|
"""
|
||||||
|
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||||
|
super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(auto_prefix=False)
|
||||||
|
self.network = network
|
||||||
|
self.network.set_grad()
|
||||||
|
self.weights = optimizer.parameters
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.grad = C.GradOperation(get_by_list=True,
|
||||||
|
sens_param=True)
|
||||||
|
self.reducer_flag = False
|
||||||
|
self.allreduce = P.AllReduce()
|
||||||
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||||
|
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||||
|
self.reducer_flag = True
|
||||||
|
self.grad_reducer = F.identity
|
||||||
|
self.degree = 1
|
||||||
|
if self.reducer_flag:
|
||||||
|
self.degree = get_group_size()
|
||||||
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||||
|
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||||
|
self.cast = P.Cast()
|
||||||
|
if context.get_context("device_target") == "GPU":
|
||||||
|
self.gpu_target = True
|
||||||
|
self.float_status = P.FloatStatus()
|
||||||
|
self.addn = P.AddN()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
else:
|
||||||
|
self.gpu_target = False
|
||||||
|
self.alloc_status = P.NPUAllocFloatStatus()
|
||||||
|
self.get_status = P.NPUGetFloatStatus()
|
||||||
|
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||||
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||||
|
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||||
|
self.base = Tensor(1, mstype.float32)
|
||||||
|
self.less_equal = P.LessEqual()
|
||||||
|
self.hyper_map = C.HyperMap()
|
||||||
|
self.loss_scale = None
|
||||||
|
self.loss_scaling_manager = scale_update_cell
|
||||||
|
if scale_update_cell:
|
||||||
|
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||||
|
|
||||||
|
@C.add_flags(has_effect=True)
|
||||||
|
def construct(self,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_type_id,
|
||||||
|
next_sentence_labels,
|
||||||
|
masked_lm_positions,
|
||||||
|
masked_lm_ids,
|
||||||
|
masked_lm_weights,
|
||||||
|
sens=None):
|
||||||
|
"""Defines the computation performed."""
|
||||||
|
weights = self.weights
|
||||||
|
loss = self.network(input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_type_id,
|
||||||
|
next_sentence_labels,
|
||||||
|
masked_lm_positions,
|
||||||
|
masked_lm_ids,
|
||||||
|
masked_lm_weights)
|
||||||
|
if sens is None:
|
||||||
|
scaling_sens = self.loss_scale
|
||||||
|
else:
|
||||||
|
scaling_sens = sens
|
||||||
|
init = False
|
||||||
|
if not self.gpu_target:
|
||||||
|
# alloc status and clear should be right before gradoperation
|
||||||
|
init = self.alloc_status()
|
||||||
|
self.clear_before_grad(init)
|
||||||
|
grads = self.grad(self.network, weights)(input_ids,
|
||||||
|
input_mask,
|
||||||
|
token_type_id,
|
||||||
|
next_sentence_labels,
|
||||||
|
masked_lm_positions,
|
||||||
|
masked_lm_ids,
|
||||||
|
masked_lm_weights,
|
||||||
|
self.cast(scaling_sens,
|
||||||
|
mstype.float32))
|
||||||
|
# apply grad reducer on grads
|
||||||
|
grads = self.grad_reducer(grads)
|
||||||
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||||
|
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||||
|
if not self.gpu_target:
|
||||||
|
self.get_status(init)
|
||||||
|
flag_sum = self.reduce_sum(init, (0,))
|
||||||
|
else:
|
||||||
|
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||||
|
flag_sum = self.addn(flag_sum)
|
||||||
|
flag_sum = self.reshape(flag_sum, (()))
|
||||||
|
if self.is_distributed:
|
||||||
|
# sum overflow flag over devices
|
||||||
|
flag_reduce = self.allreduce(flag_sum)
|
||||||
|
cond = self.less_equal(self.base, flag_reduce)
|
||||||
|
else:
|
||||||
|
cond = self.less_equal(self.base, flag_sum)
|
||||||
|
overflow = cond
|
||||||
|
if self.loss_scaling_manager is not None:
|
||||||
|
overflow = self.loss_scaling_manager(scaling_sens, cond)
|
||||||
|
succ = self.optimizer(grads, overflow)
|
||||||
|
ret = (loss, cond, scaling_sens)
|
||||||
|
return F.depend(ret, succ)
|
||||||
|
|
||||||
cast = P.Cast()
|
cast = P.Cast()
|
||||||
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
|
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
|
||||||
|
|
Loading…
Reference in New Issue