forked from mindspore-Ecosystem/mindspore
parse bert dtype config
This commit is contained in:
parent
1767f3acf7
commit
011eab9f4b
|
@ -112,6 +112,16 @@ def merge(args, cfg):
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def parse_dtype(dtype):
    """
    Convert a dtype name taken from the config into the matching mstype object.

    Args:
        dtype (str): Either "mstype.float32" or "mstype.float16".

    Returns:
        mstype.float32 or mstype.float16, according to `dtype`.

    Raises:
        ValueError: If `dtype` is any other value.
    """
    if dtype == "mstype.float32":
        return mstype.float32
    if dtype == "mstype.float16":
        return mstype.float16
    # Report the rejected value so a typo in the config is easy to locate.
    raise ValueError("Not supported dtype: {!r}".format(dtype))
|
||||||
|
|
||||||
def extra_operations(cfg):
|
def extra_operations(cfg):
|
||||||
"""
|
"""
|
||||||
Do extra work on config
|
Do extra work on config
|
||||||
|
@ -125,12 +135,14 @@ def extra_operations(cfg):
|
||||||
if cfg.description == 'run_pretrain':
|
if cfg.description == 'run_pretrain':
|
||||||
cfg.AdamWeightDecay.decay_filter = create_filter_fun(cfg.AdamWeightDecay.decay_filter)
|
cfg.AdamWeightDecay.decay_filter = create_filter_fun(cfg.AdamWeightDecay.decay_filter)
|
||||||
cfg.Lamb.decay_filter = create_filter_fun(cfg.Lamb.decay_filter)
|
cfg.Lamb.decay_filter = create_filter_fun(cfg.Lamb.decay_filter)
|
||||||
cfg.base_net_cfg.dtype = mstype.float32
|
cfg.base_net_cfg.dtype = parse_dtype(cfg.base_net_cfg.dtype)
|
||||||
cfg.base_net_cfg.compute_type = mstype.float16
|
cfg.base_net_cfg.compute_type = parse_dtype(cfg.base_net_cfg.compute_type)
|
||||||
cfg.nezha_net_cfg.dtype = mstype.float32
|
cfg.nezha_net_cfg.dtype = parse_dtype(cfg.nezha_net_cfg.dtype)
|
||||||
cfg.nezha_net_cfg.compute_type = mstype.float16
|
cfg.nezha_net_cfg.compute_type = parse_dtype(cfg.nezha_net_cfg.compute_type)
|
||||||
cfg.large_net_cfg.dtype = mstype.float32
|
cfg.large_net_cfg.dtype = parse_dtype(cfg.large_net_cfg.dtype)
|
||||||
cfg.large_net_cfg.compute_type = mstype.float16
|
cfg.large_net_cfg.compute_type = parse_dtype(cfg.large_net_cfg.compute_type)
|
||||||
|
cfg.large_acc_net_cfg.dtype = parse_dtype(cfg.large_acc_net_cfg.dtype)
|
||||||
|
cfg.large_acc_net_cfg.compute_type = parse_dtype(cfg.large_acc_net_cfg.compute_type)
|
||||||
if cfg.bert_network == 'base':
|
if cfg.bert_network == 'base':
|
||||||
cfg.batch_size = cfg.base_batch_size
|
cfg.batch_size = cfg.base_batch_size
|
||||||
_bert_net_cfg = cfg.base_net_cfg
|
_bert_net_cfg = cfg.base_net_cfg
|
||||||
|
|
Loading…
Reference in New Issue