fix warning for enable_graph_kernel context in CPU device

This commit is contained in:
dayschan 2020-12-30 14:25:40 +08:00
parent da7ce4a2e9
commit a661c3dd40
2 changed files with 3 additions and 3 deletions

View File

@ -277,8 +277,8 @@ class Lamb(Optimizer):
self.global_step = Parameter(initializer(0, [1]), name='global_step')
self.assignadd = P.AssignAdd()
self.hyper_map = C.HyperMap()
self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \
context.get_context("device_target") == "Ascend"
self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
context.get_context("enable_graph_kernel")
def construct(self, gradients):
lr = self.get_lr()

View File

@ -56,7 +56,7 @@ class _OpSelector:
def __call__(self, *args, **kwargs):
_op_type = _OpSelector.DEFAULT_OP_TYPE
if context.get_context("enable_graph_kernel"):
if context.get_context("device_target") in ['Ascend', 'GPU'] and context.get_context("enable_graph_kernel"):
if _OpSelector.KW_STR in kwargs:
_op_type = kwargs.get(_OpSelector.KW_STR)
kwargs.pop(_OpSelector.KW_STR, None)