forked from mindspore-Ecosystem/mindspore
parse bert dtype config
This commit is contained in:
parent
1767f3acf7
commit
011eab9f4b
|
@ -112,6 +112,16 @@ def merge(args, cfg):
|
|||
return cfg
|
||||
|
||||
|
||||
def parse_dtype(dtype):
|
||||
if dtype not in ["mstype.float32", "mstype.float16"]:
|
||||
raise ValueError("Not supported dtype")
|
||||
|
||||
if dtype == "mstype.float32":
|
||||
return mstype.float32
|
||||
if dtype == "mstype.float16":
|
||||
return mstype.float16
|
||||
return None
|
||||
|
||||
def extra_operations(cfg):
|
||||
"""
|
||||
Do extra work on config
|
||||
|
@ -125,12 +135,14 @@ def extra_operations(cfg):
|
|||
if cfg.description == 'run_pretrain':
|
||||
cfg.AdamWeightDecay.decay_filter = create_filter_fun(cfg.AdamWeightDecay.decay_filter)
|
||||
cfg.Lamb.decay_filter = create_filter_fun(cfg.Lamb.decay_filter)
|
||||
cfg.base_net_cfg.dtype = mstype.float32
|
||||
cfg.base_net_cfg.compute_type = mstype.float16
|
||||
cfg.nezha_net_cfg.dtype = mstype.float32
|
||||
cfg.nezha_net_cfg.compute_type = mstype.float16
|
||||
cfg.large_net_cfg.dtype = mstype.float32
|
||||
cfg.large_net_cfg.compute_type = mstype.float16
|
||||
cfg.base_net_cfg.dtype = parse_dtype(cfg.base_net_cfg.dtype)
|
||||
cfg.base_net_cfg.compute_type = parse_dtype(cfg.base_net_cfg.compute_type)
|
||||
cfg.nezha_net_cfg.dtype = parse_dtype(cfg.nezha_net_cfg.dtype)
|
||||
cfg.nezha_net_cfg.compute_type = parse_dtype(cfg.nezha_net_cfg.compute_type)
|
||||
cfg.large_net_cfg.dtype = parse_dtype(cfg.large_net_cfg.dtype)
|
||||
cfg.large_net_cfg.compute_type = parse_dtype(cfg.large_net_cfg.compute_type)
|
||||
cfg.large_acc_net_cfg.dtype = parse_dtype(cfg.large_acc_net_cfg.dtype)
|
||||
cfg.large_acc_net_cfg.compute_type = parse_dtype(cfg.large_acc_net_cfg.compute_type)
|
||||
if cfg.bert_network == 'base':
|
||||
cfg.batch_size = cfg.base_batch_size
|
||||
_bert_net_cfg = cfg.base_net_cfg
|
||||
|
|
Loading…
Reference in New Issue