!31216 custom julia only support cpu

Merge pull request !31216 from r1chardf1d0/julia2
This commit is contained in:
i-robot 2022-03-14 02:51:44 +00:00 committed by Gitee
commit d0342f9e92
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 7 additions and 1 deletions

View File

@ -22,6 +22,7 @@ import hashlib
import inspect
import numpy as np
from mindspore._c_expression import Oplib, typing
from mindspore import context
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import DataType
@ -685,9 +686,14 @@ class Custom(ops.PrimitiveWithInfer):
self.add_prim_attr("primitive_target", "CPU")
elif self.func_type == "julia":
self.add_prim_attr("primitive_target", "CPU")
if registered_targets and registered_targets != ["CPU"]:
device_target = context.get_context('device_target')
if device_target == "CPU":
pass
elif device_target == "GPU" and registered_targets and registered_targets == ["CPU"]:
logger.warning("CustomJulia only supports CPU platform, but gets registered target as {}."
"We will run CustomJulia on CPU".format(registered_targets))
else:
raise ValueError("CustomJulia only supports CPU platform, but gets target as {}.".format(device_target))
def _update_attr(self):
"""Add input_names, attr_names, primitive_target to primitive's attr."""