forked from mindspore-Ecosystem/mindspore
!31216 custom julia only support cpu
Merge pull request !31216 from r1chardf1d0/julia2
This commit is contained in:
commit
d0342f9e92
|
@ -22,6 +22,7 @@ import hashlib
|
|||
import inspect
|
||||
import numpy as np
|
||||
from mindspore._c_expression import Oplib, typing
|
||||
from mindspore import context
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import DataType
|
||||
|
@ -685,9 +686,14 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
self.add_prim_attr("primitive_target", "CPU")
|
||||
elif self.func_type == "julia":
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
if registered_targets and registered_targets != ["CPU"]:
|
||||
device_target = context.get_context('device_target')
|
||||
if device_target == "CPU":
|
||||
pass
|
||||
elif device_target == "GPU" and registered_targets and registered_targets == ["CPU"]:
|
||||
logger.warning("CustomJulia only supports CPU platform, but gets registered target as {}."
|
||||
"We will run CustomJulia on CPU".format(registered_targets))
|
||||
else:
|
||||
raise ValueError("CustomJulia only supports CPU platform, but gets target as {}.".format(device_target))
|
||||
|
||||
def _update_attr(self):
|
||||
"""Add input_names, attr_names, primitive_target to primitive's attr."""
|
||||
|
|
Loading…
Reference in New Issue