forked from mindspore-Ecosystem/mindspore
fix warning for enable_graph_kernel context in CPU device
This commit is contained in:
parent
da7ce4a2e9
commit
a661c3dd40
|
@ -277,8 +277,8 @@ class Lamb(Optimizer):
|
|||
self.global_step = Parameter(initializer(0, [1]), name='global_step')
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \
|
||||
context.get_context("device_target") == "Ascend"
|
||||
self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
|
||||
context.get_context("enable_graph_kernel")
|
||||
|
||||
def construct(self, gradients):
|
||||
lr = self.get_lr()
|
||||
|
|
|
@ -56,7 +56,7 @@ class _OpSelector:
|
|||
|
||||
def __call__(self, *args, **kwargs):
|
||||
_op_type = _OpSelector.DEFAULT_OP_TYPE
|
||||
if context.get_context("enable_graph_kernel"):
|
||||
if context.get_context("device_target") in ['Ascend', 'GPU'] and context.get_context("enable_graph_kernel"):
|
||||
if _OpSelector.KW_STR in kwargs:
|
||||
_op_type = kwargs.get(_OpSelector.KW_STR)
|
||||
kwargs.pop(_OpSelector.KW_STR, None)
|
||||
|
|
Loading…
Reference in New Issue