!10828 fix warning for enable_graph_kernel context in CPU device

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2020-12-30 18:41:13 +08:00 committed by Gitee
commit 76bd0f1245
2 changed files with 3 additions and 3 deletions

View File

@ -277,8 +277,8 @@ class Lamb(Optimizer):
self.global_step = Parameter(initializer(0, [1]), name='global_step') self.global_step = Parameter(initializer(0, [1]), name='global_step')
self.assignadd = P.AssignAdd() self.assignadd = P.AssignAdd()
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \ self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
context.get_context("device_target") == "Ascend" context.get_context("enable_graph_kernel")
def construct(self, gradients): def construct(self, gradients):
lr = self.get_lr() lr = self.get_lr()

View File

@ -56,7 +56,7 @@ class _OpSelector:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
_op_type = _OpSelector.DEFAULT_OP_TYPE _op_type = _OpSelector.DEFAULT_OP_TYPE
if context.get_context("enable_graph_kernel"): if context.get_context("device_target") in ['Ascend', 'GPU'] and context.get_context("enable_graph_kernel"):
if _OpSelector.KW_STR in kwargs: if _OpSelector.KW_STR in kwargs:
_op_type = kwargs.get(_OpSelector.KW_STR) _op_type = kwargs.get(_OpSelector.KW_STR)
kwargs.pop(_OpSelector.KW_STR, None) kwargs.pop(_OpSelector.KW_STR, None)