From 011eab9f4b52259bbf8117124445cff77a27b57d Mon Sep 17 00:00:00 2001 From: chenhaozhe Date: Wed, 14 Jul 2021 18:32:51 +0800 Subject: [PATCH] parse bert dtype config --- .../nlp/bert/src/model_utils/config.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/model_zoo/official/nlp/bert/src/model_utils/config.py b/model_zoo/official/nlp/bert/src/model_utils/config.py index 49680622182..d36689c29cd 100644 --- a/model_zoo/official/nlp/bert/src/model_utils/config.py +++ b/model_zoo/official/nlp/bert/src/model_utils/config.py @@ -112,6 +112,16 @@ def merge(args, cfg): return cfg +def parse_dtype(dtype): + if dtype not in ["mstype.float32", "mstype.float16"]: + raise ValueError("Not supported dtype") + + if dtype == "mstype.float32": + return mstype.float32 + if dtype == "mstype.float16": + return mstype.float16 + return None + def extra_operations(cfg): """ Do extra work on config @@ -125,12 +135,14 @@ def extra_operations(cfg): if cfg.description == 'run_pretrain': cfg.AdamWeightDecay.decay_filter = create_filter_fun(cfg.AdamWeightDecay.decay_filter) cfg.Lamb.decay_filter = create_filter_fun(cfg.Lamb.decay_filter) - cfg.base_net_cfg.dtype = mstype.float32 - cfg.base_net_cfg.compute_type = mstype.float16 - cfg.nezha_net_cfg.dtype = mstype.float32 - cfg.nezha_net_cfg.compute_type = mstype.float16 - cfg.large_net_cfg.dtype = mstype.float32 - cfg.large_net_cfg.compute_type = mstype.float16 + cfg.base_net_cfg.dtype = parse_dtype(cfg.base_net_cfg.dtype) + cfg.base_net_cfg.compute_type = parse_dtype(cfg.base_net_cfg.compute_type) + cfg.nezha_net_cfg.dtype = parse_dtype(cfg.nezha_net_cfg.dtype) + cfg.nezha_net_cfg.compute_type = parse_dtype(cfg.nezha_net_cfg.compute_type) + cfg.large_net_cfg.dtype = parse_dtype(cfg.large_net_cfg.dtype) + cfg.large_net_cfg.compute_type = parse_dtype(cfg.large_net_cfg.compute_type) + cfg.large_acc_net_cfg.dtype = parse_dtype(cfg.large_acc_net_cfg.dtype) + cfg.large_acc_net_cfg.compute_type = parse_dtype(cfg.large_acc_net_cfg.compute_type) if cfg.bert_network == 'base': cfg.batch_size = cfg.base_batch_size _bert_net_cfg = cfg.base_net_cfg