forked from mindspore-Ecosystem/mindspore
bugfix bert script
parent 99ffe64bb8
commit a5ac2427a7
@@ -50,7 +50,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
                                                    power=optimizer_cfg.AdamWeightDecay.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
+        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                         {'params': other_params, 'weight_decay': 0.0}]
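The `cfg` -> `optimizer_cfg` hunks in this commit all fix the same bug: the surrounding block consistently uses `optimizer_cfg`, while the old lambda referenced `cfg`, a name that is either undefined in this scope (raising NameError when the AdamWeightDecay branch runs) or points at the wrong config object. A minimal, self-contained sketch of the fixed grouping logic; the config object and its decay_filter below are illustrative stand-ins, not the script's real values:

# Sketch of the fixed parameter grouping; `optimizer_cfg` is a stand-in.
from types import SimpleNamespace

optimizer_cfg = SimpleNamespace(AdamWeightDecay=SimpleNamespace(
    weight_decay=0.01,
    # common BERT convention (assumed here): no decay on bias/LayerNorm
    decay_filter=lambda p: 'bias' not in p.name.lower() and 'layernorm' not in p.name.lower()))

params = [SimpleNamespace(name='dense.weight'), SimpleNamespace(name='dense.bias')]
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
# The buggy line used `cfg.AdamWeightDecay...` here instead of `optimizer_cfg`.
other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                {'params': other_params, 'weight_decay': 0.0}]
assert [p.name for p in decay_params] == ['dense.weight']
assert [p.name for p in other_params] == ['dense.bias']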
@@ -70,7 +70,9 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin

     # load checkpoint into network
     ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
-    ckpoint_cb = ModelCheckpoint(prefix="classifier", directory=save_checkpoint_path, config=ckpt_config)
+    ckpoint_cb = ModelCheckpoint(prefix="classifier",
+                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
+                                 config=ckpt_config)
     param_dict = load_checkpoint(load_checkpoint_path)
     load_param_into_net(network, param_dict)
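The second half of the fix guards the checkpoint callback against an empty --save_checkpoint_path. In MindSpore, ModelCheckpoint's directory parameter defaults to None, which saves into the current working directory, whereas an empty string is not a usable directory value. A sketch of the guard in isolation (the wrapper function is mine, not the script's):

from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

def make_ckpt_cb(save_checkpoint_path, steps_per_epoch):
    """Build the checkpoint callback, tolerating an empty save path."""
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=1)
    return ModelCheckpoint(prefix="classifier",
                           # None selects ModelCheckpoint's default location
                           # (the current working directory)
                           directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                           config=ckpt_config)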
@@ -52,7 +52,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
                                                    power=optimizer_cfg.AdamWeightDecay.power)
         params = network.trainable_params()
         decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
+        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                         {'params': other_params, 'weight_decay': 0.0}]
         optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
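This hunk also shows where the grouped parameters end up: they are passed straight to MindSpore's AdamWeightDecay optimizer, and each group's weight_decay overrides the optimizer-wide default. A runnable sketch with a toy network; the layer sizes and learning rate are illustrative assumptions:

from mindspore import nn

net = nn.Dense(4, 2)                     # toy stand-in for the BERT network
params = net.trainable_params()
decay_params = [p for p in params if 'bias' not in p.name]
other_params = [p for p in params if 'bias' in p.name]
group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params, 'weight_decay': 0.0}]
# Per-group weight_decay overrides the optimizer-level default of 0.0.
optimizer = nn.AdamWeightDecay(group_params, learning_rate=2e-5, eps=1e-6)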
@@ -71,7 +71,9 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin

     # load checkpoint into network
     ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
-    ckpoint_cb = ModelCheckpoint(prefix="ner", directory=save_checkpoint_path, config=ckpt_config)
+    ckpoint_cb = ModelCheckpoint(prefix="ner",
+                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
+                                 config=ckpt_config)
     param_dict = load_checkpoint(load_checkpoint_path)
     load_param_into_net(network, param_dict)
@@ -51,7 +51,7 @@ def run_pretrain():
     parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
     parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
     parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
-    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
+    parser.add_argument("--save_checkpoint_path", type=str, default=None, help="Save checkpoint path")
     parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                         "default is 1000.")
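Changing the --save_checkpoint_path default from "" to None means the "no path given" case already carries the value that ModelCheckpoint's directory parameter expects, so the pretrain script needs no empty-string guard downstream. A quick check of the argparse behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_checkpoint_path", type=str, default=None,
                    help="Save checkpoint path")
args = parser.parse_args([])              # flag omitted on the command line
assert args.save_checkpoint_path is None  # safe to pass to ModelCheckpoint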
@@ -142,7 +142,7 @@ def run_pretrain():
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                          format(cfg.optimizer))
     callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
-    if args_opt.enable_save_ckpt == "true":
+    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
         ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
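This change gates checkpoint writing by device rank, so a multi-card pretraining run no longer has every device write an identical checkpoint: only the first card of each (up to) 8-card server saves. The predicate in isolation, as pure Python; the names mirror the script's args_opt fields:

def should_save_ckpt(enable_save_ckpt, device_id, device_num):
    # Save only on the first device of each 8-card server (device 0, 8, 16, ...).
    return enable_save_ckpt == "true" and device_id % min(8, device_num) == 0

assert should_save_ckpt("true", 0, 8)        # card 0 saves
assert not should_save_ckpt("true", 3, 8)    # cards 1-7 skip writing
assert should_save_ckpt("true", 0, 1)        # single-card runs still save
assert not should_save_ckpt("false", 0, 8)   # saving disabled entirely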
@@ -52,7 +52,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
                                                    power=optimizer_cfg.AdamWeightDecay.power)
         params = network.trainable_params()
         decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
+        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                         {'params': other_params, 'weight_decay': 0.0}]
@@ -72,7 +72,9 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin

     # load checkpoint into network
     ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
-    ckpoint_cb = ModelCheckpoint(prefix="squad", directory=save_checkpoint_path, config=ckpt_config)
+    ckpoint_cb = ModelCheckpoint(prefix="squad",
+                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
+                                 config=ckpt_config)
     param_dict = load_checkpoint(load_checkpoint_path)
     load_param_into_net(network, param_dict)
@@ -99,7 +99,7 @@ def run_general_distill():
                                                    power=common_cfg.AdamWeightDecay.power)
     params = netwithloss.trainable_params()
     decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
-    other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
+    other_params = list(filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x), params))
     group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
                     {'params': other_params, 'weight_decay': 0.0},
                     {'order_params': params}]
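The three distillation hunks (run_general_distill, run_predistill, run_task_distill) differ from the fine-tuning ones only in the extra {'order_params': params} entry. In MindSpore's grouped-parameter syntax this key does not define a group of its own; it tells the optimizer to apply updates in the network's original parameter order rather than the concatenated group order. A sketch with the same toy network as above:

from mindspore import nn

net = nn.Dense(4, 2)
params = net.trainable_params()
decay_params = [p for p in params if 'bias' not in p.name]
other_params = [p for p in params if 'bias' in p.name]
group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params, 'weight_decay': 0.0},
                # keep optimizer-internal ordering aligned with the network
                {'order_params': params}]
optimizer = nn.AdamWeightDecay(group_params, learning_rate=5e-5, eps=1e-6)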
@@ -107,7 +107,7 @@ def run_predistill():
                                                    power=optimizer_cfg.AdamWeightDecay.power)
     params = netwithloss.trainable_params()
     decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-    other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
+    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
     group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                     {'params': other_params, 'weight_decay': 0.0},
                     {'order_params': params}]
@@ -165,7 +165,7 @@ def run_task_distill(ckpt_file):
                                                    power=optimizer_cfg.AdamWeightDecay.power)
     params = netwithloss.trainable_params()
     decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-    other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
+    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
     group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                     {'params': other_params, 'weight_decay': 0.0},
                     {'order_params': params}]