From a661c3dd40ca2a7a4cda27df532258b9a99d9837 Mon Sep 17 00:00:00 2001
From: dayschan
Date: Wed, 30 Dec 2020 14:25:40 +0800
Subject: [PATCH] fix warning for enable_graph_kernel context on CPU device

---
 mindspore/nn/optim/lamb.py   | 4 ++--
 mindspore/ops/op_selector.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py
index 706cba4267e..19be613615b 100755
--- a/mindspore/nn/optim/lamb.py
+++ b/mindspore/nn/optim/lamb.py
@@ -277,8 +277,8 @@ class Lamb(Optimizer):
         self.global_step = Parameter(initializer(0, [1]), name='global_step')
         self.assignadd = P.AssignAdd()
         self.hyper_map = C.HyperMap()
-        self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \
-            context.get_context("device_target") == "Ascend"
+        self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
+            context.get_context("enable_graph_kernel")
 
     def construct(self, gradients):
         lr = self.get_lr()
diff --git a/mindspore/ops/op_selector.py b/mindspore/ops/op_selector.py
index 2020a161d26..cff0a330f03 100644
--- a/mindspore/ops/op_selector.py
+++ b/mindspore/ops/op_selector.py
@@ -56,7 +56,7 @@ class _OpSelector:
 
     def __call__(self, *args, **kwargs):
         _op_type = _OpSelector.DEFAULT_OP_TYPE
-        if context.get_context("enable_graph_kernel"):
+        if context.get_context("device_target") in ['Ascend', 'GPU'] and context.get_context("enable_graph_kernel"):
            if _OpSelector.KW_STR in kwargs:
                 _op_type = kwargs.get(_OpSelector.KW_STR)
                 kwargs.pop(_OpSelector.KW_STR, None)
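
Note on the fix: both hunks rely on Python's short-circuit evaluation of
"and". By testing device_target first, context.get_context("enable_graph_kernel")
is only evaluated on backends where graph kernel is supported, so the warning
that query emits on a CPU backend no longer fires. Below is a minimal,
self-contained sketch of the pattern; the get_context stub is a hypothetical
stand-in for mindspore.context.get_context and only fakes the warning for
illustration:

    import warnings

    _CTX = {"device_target": "CPU", "enable_graph_kernel": True}

    def get_context(key):
        # Hypothetical stand-in: assume the real get_context warns when
        # the graph-kernel flag is queried on a CPU backend.
        if key == "enable_graph_kernel" and _CTX["device_target"] == "CPU":
            warnings.warn("graph kernel is not supported on CPU")
        return _CTX[key]

    # Before the patch: the graph-kernel flag is queried first, so the
    # warning fires even though the device is CPU.
    enable_graph_kernel = get_context("enable_graph_kernel") and \
                          get_context("device_target") == "Ascend"

    # After the patch: the device check runs first, and short-circuit "and"
    # skips the graph-kernel query entirely on CPU, so no warning is raised.
    enable_graph_kernel = get_context("device_target") == "Ascend" and \
                          get_context("enable_graph_kernel")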