add set_device interface for Primitive

kswang 2022-09-23 10:58:01 +08:00
parent 7253f40ad3
commit dc236c4753
30 changed files with 73 additions and 56 deletions


@@ -2607,7 +2607,7 @@ The following optimizers add the target interface: Adam, FTRL, LazyAdam, Proxim
>>>
>>> net = LeNet5()
>>> optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
->>> optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
+>>> optimizer.sparse_opt.set_device("CPU")
```
</td>


@@ -8,7 +8,7 @@ mindspore.nn.EmbeddingLookup
Functions the same as the embedding layer; it is mainly used for heterogeneous parallel scenarios with large-scale embedding layers under automatic or semi-automatic parallelism.
.. note::
-When 'target' is set to 'CPU', this module uses ops.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which specifies 'offset = 0' for the lookup table.
+When 'target' is set to 'CPU', this module uses ops.EmbeddingLookup().set_device('CPU'), which specifies 'offset = 0' for the lookup table.
When 'target' is set to 'DEVICE', this module uses ops.Gather(), which specifies 'axis = 0' for the lookup table.
In field slice mode, manual_shapes must be specified. This tuple contains vocab[i] elements, where vocab[i] is the row number of the i-th part.


@@ -6,7 +6,7 @@ mindspore.nn.MultiFieldEmbeddingLookup
Returns slices of the input Tensor according to the specified indices and field IDs. This operator supports multi-hot and one-hot embedding lookups simultaneously.
.. note::
-When 'target' is set to 'CPU', this module uses P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which specifies 'offset = 0' for the lookup table.
+When 'target' is set to 'CPU', this module uses P.EmbeddingLookup().set_device('CPU'), which specifies 'offset = 0' for the lookup table.
When 'target' is set to 'DEVICE', this module uses P.Gather(), which specifies 'axis = 0' for the lookup table.


@@ -16,6 +16,13 @@ mindspore.ops.Primitive
- **name** (str) - Name of the attribute.
- **value** (Any) - Value of the attribute.
+.. py:method:: set_device(device_target)
+
+    Set the backend on which the Primitive executes.
+
+    Parameters:
+        - **device_target** (str) - Backend name; CPU, GPU, and Ascend are supported.
+
.. py:method:: check_elim(*args)
Check whether this Primitive can be eliminated. Subclasses that need this can override the method.
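The hunks below swap the old attribute call for this new method throughout the code base. A minimal before/after sketch of the interface (hedged: ops.ReLU is a stand-in primitive, not one touched by this commit):

```python
import mindspore.ops as ops

# Before this commit: pin a primitive to a backend via a raw attribute.
relu_cpu = ops.ReLU().add_prim_attr("primitive_target", "CPU")

# After this commit: the same effect through the dedicated interface.
relu_cpu = ops.ReLU().set_device("CPU")
```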


@@ -164,7 +164,7 @@ class EmbeddingLookup(Cell):
Note:
When 'target' is set to 'CPU', this module will use
-P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
+P.EmbeddingLookup().set_device('CPU') which
specified 'offset = 0' to lookup table.
When 'target' is set to 'DEVICE', this module will use P.Gather() which
specified 'axis = 0' to lookup table.
@@ -245,7 +245,7 @@ class EmbeddingLookup(Cell):
self.gatherv2 = P.SparseGatherV2()
else:
self.gatherv2 = P.Gather()
-self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+self.embeddinglookup = P.EmbeddingLookup().set_device('CPU')
self.is_ps_server = False
enable_ps = _get_ps_context("enable_ps")
if enable_ps:
@@ -401,7 +401,7 @@ class EmbeddingLookup(Cell):
# Add EmbeddingLookup ops on different servers.
if self.target == 'CPU':
-embedding_lookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+embedding_lookup = P.EmbeddingLookup().set_device('CPU')
else:
if self.sparse:
embedding_lookup = P.SparseGatherV2()
@@ -476,7 +476,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
Note:
When 'target' is set to 'CPU', this module will use
-P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
+P.EmbeddingLookup().set_device('CPU') which
specified 'offset = 0' to lookup table.
When 'target' is set to 'DEVICE', this module will use P.Gather() which
specified 'axis = 0' to lookup table.


@@ -749,7 +749,7 @@ class EmbeddingLookupThor(Cell):
self.gatherv2 = P.SparseGatherV2()
else:
self.gatherv2 = P.Gather()
-self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
+self.embeddinglookup = P.EmbeddingLookup().set_device('CPU')
enable_ps = _get_ps_context("enable_ps")
if enable_ps:
self._process_vocab_cache(slice_mode)
@@ -866,11 +866,11 @@ class EmbeddingLookupThor(Cell):
logger.info("EmbeddingLookup cache enable takes effect.")
self.forward_unique = True
-self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
+self.unique = P.Unique().set_device('CPU')
self.unique.add_prim_attr('cache_enable', True)
self.embedding_table.cache_enable = self.cache_enable
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
-self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
+self.reshape_first = P.Reshape().set_device('CPU')
def _process_vocab_cache(self, slice_mode):
"""PS embeddingLookup cache check and process."""


@@ -435,7 +435,7 @@ class AdaFactor(Optimizer):
"""
self._set_base_target(value)
if value == 'CPU':
-self.fused_ada_factor.add_prim_attr("primitive_target", "CPU")
+self.fused_ada_factor.set_device("CPU")
self.use_fused_ada_factor = True
else:
self.use_fused_ada_factor = False
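The optimizer hunks in this commit all follow the pattern above: when the target is 'CPU', the fused or sparse kernel is pinned to the CPU backend. A minimal user-side sketch, mirroring the migration note near the top of this commit (nn.Dense stands in for LeNet5):

```python
import mindspore.nn as nn

net = nn.Dense(4, 2)  # stand-in network for illustration
optimizer = nn.Adam(net.trainable_params())
# Pin the optimizer's sparse kernel to the CPU backend, as in the
# migration note at the top of this commit.
optimizer.sparse_opt.set_device("CPU")
```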


@@ -515,7 +515,7 @@ class Adam(Optimizer):
else:
self.opt = P.Adam(use_locking, use_nesterov)
self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
-self.sparse_opt.add_prim_attr("primitive_target", "CPU")
+self.sparse_opt.set_device("CPU")
self._ps_pull = P.Pull()
if use_amsgrad:
self._ps_push = P.Push("ApplyAdamWithAmsgrad", [0, 1, 2, 3])
@@ -802,7 +802,7 @@ class AdamWeightDecay(Optimizer):
"""
self._set_base_target(value)
if value == 'CPU':
-self.fused_opt.add_prim_attr("primitive_target", "CPU")
+self.fused_opt.set_device("CPU")
self.use_fused_opt = True
else:
self.use_fused_opt = False
@@ -963,7 +963,7 @@ class AdamOffload(Optimizer):
self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
-self.opt.add_prim_attr("primitive_target", "CPU")
+self.opt.set_device("CPU")
@ms_function
def construct(self, gradients):


@@ -316,7 +316,7 @@ class FTRL(Optimizer):
if value == 'CPU':
self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)
-self.sparse_opt.add_prim_attr("primitive_target", "CPU")
+self.sparse_opt.set_device("CPU")
else:
self.sparse_opt = P.SparseApplyFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)


@@ -352,7 +352,7 @@ class LazyAdam(Optimizer):
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
self.opt = P.Adam(use_locking, use_nesterov)
self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
-self.sparse_opt.add_prim_attr("primitive_target", "CPU")
+self.sparse_opt.set_device("CPU")
self._ps_pull = P.Pull()
self._ps_push = P.Push("Adam", [0, 1, 2])
self._ps_push.add_prim_attr("use_nesterov", use_nesterov)


@@ -231,7 +231,8 @@ class ProximalAdagrad(Optimizer):
"but got {}.".format(value))
if value == 'CPU':
-self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking).add_prim_attr("primitive_target", "CPU")
+self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking)
+self.sparse_opt.set_device("CPU")
else:
self.sparse_opt = P.SparseApplyProximalAdagrad(self.use_locking)


@@ -2283,7 +2283,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
-self.add_prim_attr('primitive_target', 'CPU')
+self.set_device('CPU')
self.tuple_setitem = Primitive('tuple_setitem')
def __infer__(self, dy, split_num):


@@ -879,7 +879,7 @@ class Custom(ops.PrimitiveWithInfer):
"""Add primitive_target to primitive's attr."""
registered_targets = self._get_registered_targets()
if self.func_type == "pyfunc":
-self.add_prim_attr("primitive_target", "CPU")
+self.set_device("CPU")
if registered_targets and registered_targets != ["CPU"]:
logger.warning("{}, only supports CPU platform, but got registered target {}. "
"We will run it on CPU".format(self.log_prefix, registered_targets))
@@ -887,11 +887,11 @@ class Custom(ops.PrimitiveWithInfer):
if len(registered_targets) != 1:
logger.info("{}, target will be set according to context.".format(self.log_prefix))
elif registered_targets == ["GPU"]:
-self.add_prim_attr("primitive_target", "GPU")
+self.set_device("GPU")
elif registered_targets == ["CPU"]:
-self.add_prim_attr("primitive_target", "CPU")
+self.set_device("CPU")
elif self.func_type == "julia":
-self.add_prim_attr("primitive_target", "CPU")
+self.set_device("CPU")
device_target = context.get_context('device_target')
if device_target == "CPU":
pass


@@ -81,6 +81,15 @@ class Primitive(Primitive_):
        self.add_attr(name, value)
        return self

+    def set_device(self, device_target):
+        """
+        Set the device on which the primitive executes.
+
+        Args:
+            device_target (str): The target device to run on; "Ascend", "GPU", and "CPU" are supported.
+        """
+        return self.add_prim_attr("primitive_target", device_target)
+
    def _fill_signature(self, signatures):
        """fills signature."""
        signatures_new = []
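Because set_device simply forwards to add_prim_attr and returns the primitive, it chains inside constructor expressions exactly like the call it replaces. A small self-contained sketch in the style of the Unique tests further down (the input values are illustrative only):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)

class UniqueSquare(nn.Cell):
    def __init__(self):
        super(UniqueSquare, self).__init__()
        # Unique is pinned to the CPU backend; Square runs on the default device.
        self.unique_cpu = P.Unique().set_device("CPU")
        self.square = P.Square()

    def construct(self, x):
        y, _ = self.unique_cpu(x)  # Unique returns (values, indices)
        return self.square(y)

net = UniqueSquare()
out = net(Tensor(np.array([1, 1, 2, 2, 3], dtype=np.int32)))
```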


@@ -189,7 +189,7 @@ class OptimizerSemiAutoAndAutoParallel6Net(Cell):
self.sigmoid = P.Sigmoid()
self.add1 = P.Add()
self.add2 = P.Add()
-self.mul1 = P.Mul().add_prim_attr('primitive_target', 'CPU')
+self.mul1 = P.Mul().set_device('CPU')
self.mul2 = P.Mul()
self.mul3 = P.Mul()
self.flatten = Flatten()


@@ -203,7 +203,7 @@ class ParallelStrategySearchNet(Cell):
self.tensor = Tensor(np.full(test_size, 0.05, dtype=np.float32))
self.softmax = Softmax(axis=axis)
self.relu = ReLU()
-self.relu.relu.add_prim_attr("primitive_target", "CPU")
+self.relu.relu.set_device("CPU")
self.reshape = P.Reshape()
self.input_shape = input_shape
self.equal = P.Equal()


@@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE)
class Unique(nn.Cell):
def __init__(self):
super(Unique, self).__init__()
-self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
+self.unique_cpu = P.Unique().set_device("CPU")
def construct(self, x):
x, y = self.unique_cpu(x)
@@ -36,7 +36,7 @@ class Unique(nn.Cell):
class UniqueSquare(nn.Cell):
def __init__(self):
super(UniqueSquare, self).__init__()
-self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
+self.unique_cpu = P.Unique().set_device("CPU")
self.square = P.Square()
def construct(self, x):
@@ -47,8 +47,8 @@ class UniqueSquare(nn.Cell):
class UniqueSquareRelu(nn.Cell):
def __init__(self):
super(UniqueSquareRelu, self).__init__()
-self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
-self.square_cpu = P.Square().add_prim_attr("primitive_target", "CPU")
+self.unique_cpu = P.Unique().set_device("CPU")
+self.square_cpu = P.Square().set_device("CPU")
self.relu = P.ReLU()
def construct(self, x):
@@ -60,9 +60,9 @@ class UniqueSquareRelu(nn.Cell):
class UniqueReshapeAdd(nn.Cell):
def __init__(self):
super(UniqueReshapeAdd, self).__init__()
-self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
+self.unique_cpu = P.Unique().set_device("CPU")
self.unique = P.Unique()
-self.reshape_cpu = P.Reshape().add_prim_attr("primitive_target", "CPU")
+self.reshape_cpu = P.Reshape().set_device("CPU")
self.reshape = P.Reshape()
self.add = P.Add()


@@ -42,7 +42,7 @@ class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.relu1 = P.ReLU()
-self.relu2 = P.ReLU().add_prim_attr("primitive_target", "CPU")
+self.relu2 = P.ReLU().set_device("CPU")
self.mul = P.Mul()
self.depend = P.Depend()


@@ -138,7 +138,7 @@ class FusedAdamWeightDecayWithGlobalNorm(Optimizer):
self.moments2 = clone_state(self._parameters, prefix="adam_v", init='zeros')
self.norm = GlobalNorm()
self.opt = P.FusedCastAdamWeightDecay()
-self.opt.add_prim_attr("primitive_target", "CPU")
+self.opt.set_device("CPU")
def construct(self, gradients):
"""construct with gradients"""


@@ -36,14 +36,14 @@ class LeNet(nn.Cell):
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
-self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
-self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
+self.fc1.matmul.set_device("CPU")
+self.fc1.bias_add.set_device("CPU")
self.fc2 = nn.Dense(120, 84)
-self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
-self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
+self.fc2.matmul.set_device("CPU")
+self.fc2.bias_add.set_device("CPU")
self.fc3 = nn.Dense(84, 10)
-self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
-self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
+self.fc3.matmul.set_device("CPU")
+self.fc3.bias_add.set_device("CPU")
def construct(self, input_x):
output = self.conv1(input_x)


@@ -77,8 +77,8 @@ def test_heterogeneous_excutor():
# heterogeneous_excutor
net_heter = Net(3, 10)
-net_heter.relu.relu.add_prim_attr("primitive_target", "CPU")
-net_heter.conv.conv2d.add_prim_attr("primitive_target", "CPU")
+net_heter.relu.relu.set_device("CPU")
+net_heter.conv.conv2d.set_device("CPU")
opt_heter = nn.Momentum(params=net_heter.trainable_params(),
learning_rate=0.001, momentum=0.0009,


@@ -58,7 +58,7 @@ class AdamWeightDecayOp(Optimizer):
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.opt = P.AdamWeightDecay()
-self.opt.add_prim_attr("primitive_target", "CPU")
+self.opt.set_device("CPU")
def construct(self, gradients):
"""AdamWeightDecayOp"""


@@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
-self.embedding = P.EmbeddingLookup().add_prim_attr("primitive_target", "CPU")
+self.embedding = P.EmbeddingLookup().set_device("CPU")
def construct(self, param, index, offset):
return self.embedding(param, index, offset)


@@ -142,8 +142,8 @@ class NetFactory:
context.set_ps_context(enable_ps=False)
net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
self.embedding_size, self.output_channels, self.target, self.sparse)
-net.conv.conv2d.add_prim_attr('primitive_target', 'CPU')
-net.conv.bias_add.add_prim_attr('primitive_target', 'CPU')
+net.conv.conv2d.set_device('CPU')
+net.conv.bias_add.set_device('CPU')
net.set_train()
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))
@@ -159,8 +159,8 @@ class NetFactory:
net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
self.embedding_size, self.output_channels, self.target, self.sparse)
net.embedding_lookup.set_param_ps()
-net.conv.conv2d.add_prim_attr('primitive_target', 'CPU')
-net.conv.bias_add.add_prim_attr('primitive_target', 'CPU')
+net.conv.conv2d.set_device('CPU')
+net.conv.bias_add.set_device('CPU')
net.set_train()
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))


@@ -49,7 +49,7 @@ def test_heterogeneous_default_ascend_prim_cpu():
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
output_device = net(inp1, inp2)
-net.relu1.add_prim_attr("primitive_target", "CPU")
+net.relu1.set_device("CPU")
output_heter = net(inp1, inp2)
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
@@ -67,7 +67,7 @@ def test_heterogeneous_default_cpu_prim_ascend():
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
output_device = net(inp1, inp2)
-net.relu1.add_prim_attr("primitive_target", "Ascend")
+net.relu1.set_device("Ascend")
output_heter = net(inp1, inp2)
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
@@ -85,7 +85,7 @@ def test_heterogeneous_default_gpu_prim_cpu():
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
output_device = net(inp1, inp2)
-net.relu1.add_prim_attr("primitive_target", "CPU")
+net.relu1.set_device("CPU")
output_heter = net(inp1, inp2)
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
@@ -103,6 +103,6 @@ def test_heterogeneous_default_cpu_prim_gpu():
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
output_device = net(inp1, inp2)
-net.relu1.add_prim_attr("primitive_target", "GPU")
+net.relu1.set_device("GPU")
output_heter = net(inp1, inp2)
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)


@@ -48,7 +48,7 @@ class Net(nn.Cell):
super().__init__()
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.offset = offset
-self.elu = P.EmbeddingLookup().shard(strategy1).add_prim_attr("primitive_target", target)
+self.elu = P.EmbeddingLookup().shard(strategy1).set_device(target)
self.mm = P.BatchMatMul().shard(strategy2)
def construct(self, x, y):


@@ -54,7 +54,7 @@ class Net(nn.Cell):
super().__init__()
if shape is None:
shape = [64, 64]
-self.gatherv2 = P.Gather().shard(strategy1, gather_out_strategy).add_prim_attr("primitive_target", target)
+self.gatherv2 = P.Gather().shard(strategy1, gather_out_strategy).set_device(target)
self.mul = P.Mul().shard(strategy2)
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.axis = axis


@@ -43,7 +43,7 @@ class Net(Cell):
super().__init__()
self.gatherv2 = P.EmbeddingLookup().shard(strategy1)
self.gatherv2.add_prim_attr(split_string, split_tuple)
-self.gatherv2.add_prim_attr("primitive_target", "CPU")
+self.gatherv2.set_device("CPU")
self.mul = P.Mul().shard(strategy2)
self.reshape = P.Reshape()
self.matmul = P.MatMul().shard(strategy3)
@@ -68,7 +68,7 @@ _b = Tensor(np.ones([8, 8, 8]), dtype=ms.float32)
def compile_net(net):
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1)
-optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
+optimizer.sparse_opt.set_device("CPU")
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()


@@ -57,7 +57,7 @@ class Net(nn.Cell):
if shape is None:
shape = [64, 64]
self.gatherv2 = P.Gather().shard(
-strategy1, gather_out_strategy).add_prim_attr("primitive_target", target)
+strategy1, gather_out_strategy).set_device(target)
self.mul = P.Mul().shard(strategy2)
self.softmax = P.Softmax().shard(strategy3)
self.index = Tensor(np.ones(shape), dtype=ms.int32)


@@ -56,7 +56,7 @@ class Net(nn.Cell):
super().__init__()
if shape is None:
shape = [64, 64]
-self.gatherv2 = P.SparseGatherV2().shard(strategy1).add_prim_attr("primitive_target", target)
+self.gatherv2 = P.SparseGatherV2().shard(strategy1).set_device(target)
self.mul = P.Mul().shard(strategy2)
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.axis = axis