add warmup_steps in AdamWeightDecayDynamicLR optimizer

This commit is contained in:
yoonlee666 2020-05-06 17:34:27 +08:00
parent 3d3b9d5474
commit eb3f70a0c7
1 changed files with 14 additions and 1 deletions

View File

@ -327,12 +327,17 @@ class AdamWeightDecayDynamicLR(Optimizer):
beta2=0.999,
eps=1e-6,
weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name,
warmup_steps=0):
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
_check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
# turn them to scalar when me support scalar/tensor mix operations
self.global_step = Parameter(initializer(0, [1]), name="global_step")
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
self.warmup_flag = False
if warmup_steps > 0:
self.warmup_flag = True
self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
self.end_learning_rate = Tensor(np.array([end_learning_rate]).astype(np.float32))
self.diff_learning_rate = Tensor(np.array([learning_rate - end_learning_rate]).astype(np.float32))
@ -348,12 +353,20 @@ class AdamWeightDecayDynamicLR(Optimizer):
self.hyper_map = C.HyperMap()
self.min = P.Minimum()
self.pow = P.Pow()
self.greater = P.Greater()
self.one = Tensor(np.array([1.0]).astype(np.float32))
self.cast = P.Cast()
self.start_learning_rate = Tensor(np.array([learning_rate]).astype(np.float32))
def construct(self, gradients):
step = self.min(self.global_step, self.decay_steps)
p = step / self.decay_steps
lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
if self.warmup_flag:
warmup_percent = self.global_step / self.warmup_steps
warmup_lr = self.start_learning_rate * warmup_percent
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)