forked from mindspore-Ecosystem/mindspore
!17474 [GraphKernel] bert-crf enable graph kernel with AdamWeightDecay on GPU.
From: @chenlei_autodiff Reviewed-by: @ckey_dou,@gaoxiong1,@gaoxiong1,@ckey_dou Signed-off-by: @gaoxiong1,@ckey_dou
This commit is contained in:
commit
1635ba91b1
|
@ -214,7 +214,7 @@ def run_ner():
|
||||||
if bert_net_cfg.compute_type != mstype.float32:
|
if bert_net_cfg.compute_type != mstype.float32:
|
||||||
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
||||||
bert_net_cfg.compute_type = mstype.float32
|
bert_net_cfg.compute_type = mstype.float32
|
||||||
if optimizer_cfg.optimizer == 'AdamWeightDecay' and args_opt.use_crf.lower() == "false":
|
if optimizer_cfg.optimizer == 'AdamWeightDecay':
|
||||||
context.set_context(enable_graph_kernel=True)
|
context.set_context(enable_graph_kernel=True)
|
||||||
else:
|
else:
|
||||||
raise Exception("Target error, GPU or Ascend is supported.")
|
raise Exception("Target error, GPU or Ascend is supported.")
|
||||||
|
|
Loading…
Reference in New Issue