add para glabal step to inceptionv3 lr generator

This commit is contained in:
zhouyaqiang 2020-11-02 09:55:21 +08:00
parent 2c91bace34
commit 065e86a4d3
1 changed files with 4 additions and 3 deletions

View File

@ -17,7 +17,7 @@ import math
import numpy as np import numpy as np
def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps): def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps, global_step=0):
""" """
Applies three steps decay to generate learning rate array. Applies three steps decay to generate learning rate array.
@ -45,6 +45,7 @@ def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
else: else:
lr = lr_max * 0.001 lr = lr_max * 0.001
lr_each_step.append(lr) lr_each_step.append(lr)
lr_each_step = np.array(lr_each_step).astype(np.float32)[global_step:]
return lr_each_step return lr_each_step
@ -131,7 +132,7 @@ def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
return lr_each_step return lr_each_step
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode, global_step=0):
""" """
generate learning rate array generate learning rate array
@ -150,7 +151,7 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
total_steps = steps_per_epoch * total_epochs total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps': if lr_decay_mode == 'steps':
lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps) lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps, global_step)
elif lr_decay_mode == 'steps_decay': elif lr_decay_mode == 'steps_decay':
lr_each_step = _generate_exponential_lr(lr_init, lr_max, total_steps, warmup_steps, steps_per_epoch) lr_each_step = _generate_exponential_lr(lr_init, lr_max, total_steps, warmup_steps, steps_per_epoch)
elif lr_decay_mode == 'cosine': elif lr_decay_mode == 'cosine':