forked from mindspore-Ecosystem/mindspore
fix accurancy lower then 92
This commit is contained in:
parent
6a7b974346
commit
8e2bb7a85c
|
@ -19,7 +19,9 @@ from easydict import EasyDict as edict
|
|||
|
||||
cifar_cfg = edict({
|
||||
'num_classes': 10,
|
||||
'lr_init': 0.05,
|
||||
'lr_init': 0.01,
|
||||
'lr_max': 0.1,
|
||||
'warmup_epochs': 5,
|
||||
'batch_size': 64,
|
||||
'epoch_size': 70,
|
||||
'momentum': 0.9,
|
||||
|
|
|
@ -38,20 +38,25 @@ random.seed(1)
|
|||
np.random.seed(1)
|
||||
|
||||
|
||||
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
|
||||
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
|
||||
"""Set learning rate."""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < decay_epoch_index[0]:
|
||||
lr_each_step.append(lr_max)
|
||||
elif i < decay_epoch_index[1]:
|
||||
lr_each_step.append(lr_max * 0.1)
|
||||
elif i < decay_epoch_index[2]:
|
||||
lr_each_step.append(lr_max * 0.01)
|
||||
if i < warmup_steps:
|
||||
lr_value = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
lr_each_step.append(lr_max * 0.001)
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr_value = float(lr_max) * base * base
|
||||
if lr_value < 0.0:
|
||||
lr_value = 0.0
|
||||
lr_each_step.append(lr_value)
|
||||
|
||||
current_step = global_step
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
|
@ -86,7 +91,8 @@ if __name__ == '__main__':
|
|||
if args_opt.pre_trained:
|
||||
load_param_into_net(net, load_checkpoint(args_opt.pre_trained))
|
||||
|
||||
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
|
||||
lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
|
||||
total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
|
||||
weight_decay=cfg.weight_decay)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
|
||||
|
|
Loading…
Reference in New Issue