bert thor supports lr configuration in config.py
parent d66d5fcedd
commit 8f8eee4b5e
@@ -40,6 +40,8 @@ class ConvertNetUntils():
        if subcell.activation_flag:
            act_class = subcell.activation.__class__.__name__
            act_name = act_class.lower()
            if act_name == "fastgelu":
                act_name = "fast_gelu"
        if subcell.out_channels == 1001:
            new_subcell = nn.Dense_Thor(in_channels=subcell.in_channels,
                                        out_channels=subcell.out_channels,
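For reference, the normalization added above only lowercases the activation cell's class name and maps MindSpore's FastGelu to the string "fast_gelu". A tiny self-contained illustration (the FastGelu class below is a stand-in for demonstration, not the real mindspore.nn cell):

class FastGelu:
    """Stand-in for the activation cell; only the class name matters here."""

activation = FastGelu()
act_class = activation.__class__.__name__   # "FastGelu"
act_name = act_class.lower()                # "fastgelu"
if act_name == "fastgelu":
    act_name = "fast_gelu"                  # normalized name used by the THOR conversion
print(act_name)                             # -> fast_gelu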
@@ -18,7 +18,7 @@ network config setting, will be used in train.py and eval.py
from easydict import EasyDict as ed

# config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional.
cfg = ed({
    'optimizer': 'Thor',
    'optimizer': 'Momentum',
})

# config for resent50, cifar10
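The string set here is what the training script branches on when it builds the optimizer, mirroring the elif cfg.optimizer == "Thor" branch further down in this commit. A minimal sketch of that pattern (the exact branch bodies are illustrative only):

# Sketch of how the 'optimizer' key is typically consumed.
if cfg.optimizer == "Thor":
    print("build the THOR optimizer")
elif cfg.optimizer == "Momentum":
    print("build the Momentum optimizer")
else:
    raise ValueError("Unsupported optimizer: {}".format(cfg.optimizer))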
@@ -18,7 +18,7 @@ import argparse
import ast
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum, THOR
from mindspore.nn.optim import Momentum, THOR
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.train_thor import ConvertModelUtils
@@ -100,8 +100,9 @@ def _get_optimizer(args_opt, network):
        optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == "Thor":
        from src.utils import get_bert_thor_lr, get_bert_thor_damping
        lr = get_bert_thor_lr()
        damping = get_bert_thor_damping()
        lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps)
        damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min, cfg.Thor.damping_power,
                                        cfg.Thor.damping_total_steps)
        split_indices = None
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
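The new call sites simply forward the Thor entries from config.py (the edict block added in the next hunk) into the schedule helpers. A minimal standalone sketch of that mapping, with the values copied from this commit:

from easydict import EasyDict as edict

from src.utils import get_bert_thor_lr, get_bert_thor_damping  # helpers reworked later in this commit

# Values copied from the 'Thor' block added to config.py in this commit.
thor_cfg = edict({
    'lr_max': 0.0034, 'lr_min': 3.244e-5, 'lr_power': 1.0, 'lr_total_steps': 30000,
    'damping_max': 5e-2, 'damping_min': 1e-6, 'damping_power': 1.0, 'damping_total_steps': 30000,
})

lr = get_bert_thor_lr(thor_cfg.lr_max, thor_cfg.lr_min, thor_cfg.lr_power, thor_cfg.lr_total_steps)
damping = get_bert_thor_damping(thor_cfg.damping_max, thor_cfg.damping_min,
                                thor_cfg.damping_power, thor_cfg.damping_total_steps)
# Both helpers return a Tensor with one scheduled value per training step.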
@@ -49,6 +49,14 @@ cfg = edict({
        'momentum': 0.9,
    }),
    'Thor': edict({
        'lr_max': 0.0034,
        'lr_min': 3.244e-5,
        'lr_power': 1.0,
        'lr_total_steps': 30000,
        'damping_max': 5e-2,
        'damping_min': 1e-6,
        'damping_power': 1.0,
        'damping_total_steps': 30000,
        'momentum': 0.9,
        'weight_decay': 5e-4,
        'loss_scale': 1.0,
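The remaining keys of the Thor block (momentum, weight_decay, loss_scale) are not consumed by the schedule helpers; they presumably feed the optimizer construction itself. A rough sketch of that step, in which the THOR argument names and ordering are assumptions rather than something this diff shows:

# Hypothetical sketch only: THOR's exact signature is assumed, not taken from this commit.
from mindspore.nn.optim import THOR

optimizer = THOR(network, lr, damping,
                 momentum=cfg.Thor.momentum,
                 weight_decay=cfg.Thor.weight_decay,
                 loss_scale=cfg.Thor.loss_scale,
                 split_indices=split_indices)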
@@ -22,7 +22,6 @@ import math
import collections
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
@@ -107,10 +106,11 @@ class LossCallBack(Callback):
                percent = 1
                epoch_num -= 1
            print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
                  flush=True)
        else:
            print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                               str(cb_params.net_outputs)))
                                                               str(cb_params.net_outputs)), flush=True)

def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
    """
@@ -220,22 +220,13 @@ def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps
    return learning_rate


def get_bert_thor_lr():
    if context.get_context("device_target") == "Ascend":
        learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=3.244018779068399e-05,
                                     lr_max=0.0034022148941459055, warmup_steps=0, total_steps=30000, poly_power=1)
    else:
        learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=1.7e-3, warmup_steps=0,
                                     total_steps=30000, poly_power=1)

def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000):
    learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0,
                                 total_steps=lr_total_steps, poly_power=lr_power)
    return Tensor(learning_rate)


def get_bert_thor_damping():
    if context.get_context("device_target") == "Ascend":
        damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=5e-2, warmup_steps=0, total_steps=30000,
                               poly_power=1)
    else:
        damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.5e-2, warmup_steps=0,
                               total_steps=30000, poly_power=1)
def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000):
    damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0,
                           total_steps=damping_total_steps, poly_power=damping_power)
    return Tensor(damping)
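_get_poly_lr itself is unchanged by this hunk and its body is not part of the diff; judging from its parameter names, it produces a per-step polynomial decay from lr_max down to lr_end. A rough standalone illustration of such a schedule (not the real helper, and the warmup handling in particular is an assumption):

import numpy as np

def poly_decay(lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
    """Illustrative polynomial-decay schedule; details may differ from _get_poly_lr."""
    values = []
    for step in range(total_steps):
        if warmup_steps and step < warmup_steps:
            cur = lr_init + (lr_max - lr_init) * step / warmup_steps       # linear warmup (assumed)
        else:
            frac = 1.0 - (step - warmup_steps) / max(total_steps - warmup_steps, 1)
            cur = (lr_max - lr_end) * (frac ** poly_power) + lr_end        # polynomial decay to lr_end
        values.append(cur)
    return np.array(values, dtype=np.float32)

# The default THOR damping above: 5e-2 decaying linearly (power 1.0) towards 1e-6 over 30000 steps.
damping = poly_decay(0.0, 1e-6, 5e-2, 0, 30000, 1.0)
print(damping[0], damping[-1])   # ~5e-2 at step 0, approaching 1e-6 at the end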