forked from mindspore-Ecosystem/mindspore
!10828 fix warning for enable_graph_kernel context in CPU device
From: @dayschan Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
76bd0f1245
|
@ -277,8 +277,8 @@ class Lamb(Optimizer):
|
||||||
self.global_step = Parameter(initializer(0, [1]), name='global_step')
|
self.global_step = Parameter(initializer(0, [1]), name='global_step')
|
||||||
self.assignadd = P.AssignAdd()
|
self.assignadd = P.AssignAdd()
|
||||||
self.hyper_map = C.HyperMap()
|
self.hyper_map = C.HyperMap()
|
||||||
self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \
|
self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
|
||||||
context.get_context("device_target") == "Ascend"
|
context.get_context("enable_graph_kernel")
|
||||||
|
|
||||||
def construct(self, gradients):
|
def construct(self, gradients):
|
||||||
lr = self.get_lr()
|
lr = self.get_lr()
|
||||||
|
|
|
@ -56,7 +56,7 @@ class _OpSelector:
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
_op_type = _OpSelector.DEFAULT_OP_TYPE
|
_op_type = _OpSelector.DEFAULT_OP_TYPE
|
||||||
if context.get_context("enable_graph_kernel"):
|
if context.get_context("device_target") in ['Ascend', 'GPU'] and context.get_context("enable_graph_kernel"):
|
||||||
if _OpSelector.KW_STR in kwargs:
|
if _OpSelector.KW_STR in kwargs:
|
||||||
_op_type = kwargs.get(_OpSelector.KW_STR)
|
_op_type = kwargs.get(_OpSelector.KW_STR)
|
||||||
kwargs.pop(_OpSelector.KW_STR, None)
|
kwargs.pop(_OpSelector.KW_STR, None)
|
||||||
|
|
Loading…
Reference in New Issue