forked from mindspore-Ecosystem/mindspore
add set device interface for prim
This commit is contained in:
parent
7253f40ad3
commit
dc236c4753
|
@ -2607,7 +2607,7 @@ The following optimizers add the target interface: Adam, FTRL, LazyAdam, Proxim
|
|||
>>>
|
||||
>>> net = LeNet5()
|
||||
>>> optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
|
||||
>>> optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
>>> optimizer.sparse_opt.set_device("CPU")
|
||||
```
|
||||
|
||||
</td>
|
||||
|
|
|
@ -8,7 +8,7 @@ mindspore.nn.EmbeddingLookup
|
|||
与嵌入层功能相同,主要用于自动并行或半自动并行时,存在大规模嵌入层的异构并行场景。
|
||||
|
||||
.. note::
|
||||
当'target'设置为'CPU'时,此模块将使用ops.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'),在lookup表指定了'offset = 0'。
|
||||
当'target'设置为'CPU'时,此模块将使用ops.EmbeddingLookup().set_device('CPU'),在lookup表指定了'offset = 0'。
|
||||
当'target'设置为'DEVICE'时,此模块将使用ops.Gather(),在lookup表指定了'axis = 0'。
|
||||
在字段切片模式下,必须指定manual_shapes。此tuple包含vocab[i]元素, vocab[i]是第i部分的行号。
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ mindspore.nn.MultiFieldEmbeddingLookup
|
|||
根据指定的索引和字段ID,返回输入Tensor的切片。此操作支持同时使用multi hot和one hot查找嵌入。
|
||||
|
||||
.. note::
|
||||
当'target'设置为'CPU'时,此模块将使用P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')指定'offset = 0'的查找表。
|
||||
当'target'设置为'CPU'时,此模块将使用P.EmbeddingLookup().set_device('CPU')指定'offset = 0'的查找表。
|
||||
|
||||
当'target'设置为'DEVICE'时,此模块将使用P.Gather()指定'axis = 0'的查找表。
|
||||
|
||||
|
|
|
@ -15,6 +15,13 @@ mindspore.ops.Primitive
|
|||
参数:
|
||||
- **name** (str) - 属性名称。
|
||||
- **value** (Any) - 属性值。
|
||||
|
||||
.. py:method:: set_device(device_target)
|
||||
|
||||
设置Primitive执行后端。
|
||||
|
||||
参数:
|
||||
- **device_target** (str) - 后端名称,支持CPU、GPU、Ascend。
|
||||
|
||||
.. py:method:: check_elim(*args)
|
||||
|
||||
|
|
|
@ -164,7 +164,7 @@ class EmbeddingLookup(Cell):
|
|||
|
||||
Note:
|
||||
When 'target' is set to 'CPU', this module will use
|
||||
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
|
||||
P.EmbeddingLookup().set_device('CPU') which
|
||||
specified 'offset = 0' to lookup table.
|
||||
When 'target' is set to 'DEVICE', this module will use P.Gather() which
|
||||
specified 'axis = 0' to lookup table.
|
||||
|
@ -245,7 +245,7 @@ class EmbeddingLookup(Cell):
|
|||
self.gatherv2 = P.SparseGatherV2()
|
||||
else:
|
||||
self.gatherv2 = P.Gather()
|
||||
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
self.embeddinglookup = P.EmbeddingLookup().set_device('CPU')
|
||||
self.is_ps_server = False
|
||||
enable_ps = _get_ps_context("enable_ps")
|
||||
if enable_ps:
|
||||
|
@ -401,7 +401,7 @@ class EmbeddingLookup(Cell):
|
|||
|
||||
# Add EmbeddingLookup ops on different servers.
|
||||
if self.target == 'CPU':
|
||||
embedding_lookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
embedding_lookup = P.EmbeddingLookup().set_device('CPU')
|
||||
else:
|
||||
if self.sparse:
|
||||
embedding_lookup = P.SparseGatherV2()
|
||||
|
@ -476,7 +476,7 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
|
||||
Note:
|
||||
When 'target' is set to 'CPU', this module will use
|
||||
P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
|
||||
P.EmbeddingLookup().set_device('CPU') which
|
||||
specified 'offset = 0' to lookup table.
|
||||
When 'target' is set to 'DEVICE', this module will use P.Gather() which
|
||||
specified 'axis = 0' to lookup table.
|
||||
|
|
|
@ -749,7 +749,7 @@ class EmbeddingLookupThor(Cell):
|
|||
self.gatherv2 = P.SparseGatherV2()
|
||||
else:
|
||||
self.gatherv2 = P.Gather()
|
||||
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
self.embeddinglookup = P.EmbeddingLookup().set_device('CPU')
|
||||
enable_ps = _get_ps_context("enable_ps")
|
||||
if enable_ps:
|
||||
self._process_vocab_cache(slice_mode)
|
||||
|
@ -866,11 +866,11 @@ class EmbeddingLookupThor(Cell):
|
|||
|
||||
logger.info("EmbeddingLookup cache enable takes effect.")
|
||||
self.forward_unique = True
|
||||
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
|
||||
self.unique = P.Unique().set_device('CPU')
|
||||
self.unique.add_prim_attr('cache_enable', True)
|
||||
self.embedding_table.cache_enable = self.cache_enable
|
||||
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
|
||||
self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
|
||||
self.reshape_first = P.Reshape().set_device('CPU')
|
||||
|
||||
def _process_vocab_cache(self, slice_mode):
|
||||
"""PS embeddingLookup cache check and process."""
|
||||
|
|
|
@ -435,7 +435,7 @@ class AdaFactor(Optimizer):
|
|||
"""
|
||||
self._set_base_target(value)
|
||||
if value == 'CPU':
|
||||
self.fused_ada_factor.add_prim_attr("primitive_target", "CPU")
|
||||
self.fused_ada_factor.set_device("CPU")
|
||||
self.use_fused_ada_factor = True
|
||||
else:
|
||||
self.use_fused_ada_factor = False
|
||||
|
|
|
@ -515,7 +515,7 @@ class Adam(Optimizer):
|
|||
else:
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
|
||||
self.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
self.sparse_opt.set_device("CPU")
|
||||
self._ps_pull = P.Pull()
|
||||
if use_amsgrad:
|
||||
self._ps_push = P.Push("ApplyAdamWithAmsgrad", [0, 1, 2, 3])
|
||||
|
@ -802,7 +802,7 @@ class AdamWeightDecay(Optimizer):
|
|||
"""
|
||||
self._set_base_target(value)
|
||||
if value == 'CPU':
|
||||
self.fused_opt.add_prim_attr("primitive_target", "CPU")
|
||||
self.fused_opt.set_device("CPU")
|
||||
self.use_fused_opt = True
|
||||
else:
|
||||
self.use_fused_opt = False
|
||||
|
@ -963,7 +963,7 @@ class AdamOffload(Optimizer):
|
|||
self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
|
||||
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
|
||||
self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
|
||||
self.opt.add_prim_attr("primitive_target", "CPU")
|
||||
self.opt.set_device("CPU")
|
||||
|
||||
@ms_function
|
||||
def construct(self, gradients):
|
||||
|
|
|
@ -316,7 +316,7 @@ class FTRL(Optimizer):
|
|||
|
||||
if value == 'CPU':
|
||||
self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)
|
||||
self.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
self.sparse_opt.set_device("CPU")
|
||||
else:
|
||||
self.sparse_opt = P.SparseApplyFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)
|
||||
|
||||
|
|
|
@ -352,7 +352,7 @@ class LazyAdam(Optimizer):
|
|||
self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')
|
||||
self.opt = P.Adam(use_locking, use_nesterov)
|
||||
self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
|
||||
self.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
self.sparse_opt.set_device("CPU")
|
||||
self._ps_pull = P.Pull()
|
||||
self._ps_push = P.Push("Adam", [0, 1, 2])
|
||||
self._ps_push.add_prim_attr("use_nesterov", use_nesterov)
|
||||
|
|
|
@ -231,7 +231,8 @@ class ProximalAdagrad(Optimizer):
|
|||
"but got {}.".format(value))
|
||||
|
||||
if value == 'CPU':
|
||||
self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking).add_prim_attr("primitive_target", "CPU")
|
||||
self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking)
|
||||
self.sparse_opt.set_device("CPU")
|
||||
else:
|
||||
self.sparse_opt = P.SparseApplyProximalAdagrad(self.use_locking)
|
||||
|
||||
|
|
|
@ -2283,7 +2283,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
|
||||
self.add_prim_attr('primitive_target', 'CPU')
|
||||
self.set_device('CPU')
|
||||
self.tuple_setitem = Primitive('tuple_setitem')
|
||||
|
||||
def __infer__(self, dy, split_num):
|
||||
|
|
|
@ -879,7 +879,7 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
"""Add primitive_target to primitive's attr."""
|
||||
registered_targets = self._get_registered_targets()
|
||||
if self.func_type == "pyfunc":
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
self.set_device("CPU")
|
||||
if registered_targets and registered_targets != ["CPU"]:
|
||||
logger.warning("{}, only supports CPU platform, but got registered target {}. "
|
||||
"We will run it on CPU".format(self.log_prefix, registered_targets))
|
||||
|
@ -887,11 +887,11 @@ class Custom(ops.PrimitiveWithInfer):
|
|||
if len(registered_targets) != 1:
|
||||
logger.info("{}, target will be set according to context.".format(self.log_prefix))
|
||||
elif registered_targets == ["GPU"]:
|
||||
self.add_prim_attr("primitive_target", "GPU")
|
||||
self.set_device("GPU")
|
||||
elif registered_targets == ["CPU"]:
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
self.set_device("CPU")
|
||||
elif self.func_type == "julia":
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
self.set_device("CPU")
|
||||
device_target = context.get_context('device_target')
|
||||
if device_target == "CPU":
|
||||
pass
|
||||
|
|
|
@ -81,6 +81,15 @@ class Primitive(Primitive_):
|
|||
self.add_attr(name, value)
|
||||
return self
|
||||
|
||||
def set_device(self, device_target):
|
||||
"""
|
||||
Set primitive been executed device
|
||||
|
||||
Args:
|
||||
device_target (str): The target device to run, support "Ascend", "GPU", and "CPU".
|
||||
"""
|
||||
return self.add_prim_attr("primitive_target", device_target)
|
||||
|
||||
def _fill_signature(self, signatures):
|
||||
"""fills signature."""
|
||||
signatures_new = []
|
||||
|
|
|
@ -189,7 +189,7 @@ class OptimizerSemiAutoAndAutoParallel6Net(Cell):
|
|||
self.sigmoid = P.Sigmoid()
|
||||
self.add1 = P.Add()
|
||||
self.add2 = P.Add()
|
||||
self.mul1 = P.Mul().add_prim_attr('primitive_target', 'CPU')
|
||||
self.mul1 = P.Mul().set_device('CPU')
|
||||
self.mul2 = P.Mul()
|
||||
self.mul3 = P.Mul()
|
||||
self.flatten = Flatten()
|
||||
|
|
|
@ -203,7 +203,7 @@ class ParallelStrategySearchNet(Cell):
|
|||
self.tensor = Tensor(np.full(test_size, 0.05, dtype=np.float32))
|
||||
self.softmax = Softmax(axis=axis)
|
||||
self.relu = ReLU()
|
||||
self.relu.relu.add_prim_attr("primitive_target", "CPU")
|
||||
self.relu.relu.set_device("CPU")
|
||||
self.reshape = P.Reshape()
|
||||
self.input_shape = input_shape
|
||||
self.equal = P.Equal()
|
||||
|
|
|
@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE)
|
|||
class Unique(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Unique, self).__init__()
|
||||
self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
|
||||
self.unique_cpu = P.Unique().set_device("CPU")
|
||||
|
||||
def construct(self, x):
|
||||
x, y = self.unique_cpu(x)
|
||||
|
@ -36,7 +36,7 @@ class Unique(nn.Cell):
|
|||
class UniqueSquare(nn.Cell):
|
||||
def __init__(self):
|
||||
super(UniqueSquare, self).__init__()
|
||||
self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
|
||||
self.unique_cpu = P.Unique().set_device("CPU")
|
||||
self.square = P.Square()
|
||||
|
||||
def construct(self, x):
|
||||
|
@ -47,8 +47,8 @@ class UniqueSquare(nn.Cell):
|
|||
class UniqueSquareRelu(nn.Cell):
|
||||
def __init__(self):
|
||||
super(UniqueSquareRelu, self).__init__()
|
||||
self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
|
||||
self.square_cpu = P.Square().add_prim_attr("primitive_target", "CPU")
|
||||
self.unique_cpu = P.Unique().set_device("CPU")
|
||||
self.square_cpu = P.Square().set_device("CPU")
|
||||
self.relu = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
|
@ -60,9 +60,9 @@ class UniqueSquareRelu(nn.Cell):
|
|||
class UniqueReshapeAdd(nn.Cell):
|
||||
def __init__(self):
|
||||
super(UniqueReshapeAdd, self).__init__()
|
||||
self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
|
||||
self.unique_cpu = P.Unique().set_device("CPU")
|
||||
self.unique = P.Unique()
|
||||
self.reshape_cpu = P.Reshape().add_prim_attr("primitive_target", "CPU")
|
||||
self.reshape_cpu = P.Reshape().set_device("CPU")
|
||||
self.reshape = P.Reshape()
|
||||
self.add = P.Add()
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ class Net2(nn.Cell):
|
|||
def __init__(self):
|
||||
super(Net2, self).__init__()
|
||||
self.relu1 = P.ReLU()
|
||||
self.relu2 = P.ReLU().add_prim_attr("primitive_target", "CPU")
|
||||
self.relu2 = P.ReLU().set_device("CPU")
|
||||
self.mul = P.Mul()
|
||||
self.depend = P.Depend()
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ class FusedAdamWeightDecayWithGlobalNorm(Optimizer):
|
|||
self.moments2 = clone_state(self._parameters, prefix="adam_v", init='zeros')
|
||||
self.norm = GlobalNorm()
|
||||
self.opt = P.FusedCastAdamWeightDecay()
|
||||
self.opt.add_prim_attr("primitive_target", "CPU")
|
||||
self.opt.set_device("CPU")
|
||||
|
||||
def construct(self, gradients):
|
||||
"""construct with gradients"""
|
||||
|
|
|
@ -36,14 +36,14 @@ class LeNet(nn.Cell):
|
|||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.reshape = P.Reshape()
|
||||
self.fc1 = nn.Dense(400, 120)
|
||||
self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
|
||||
self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
|
||||
self.fc1.matmul.set_device("CPU")
|
||||
self.fc1.bias_add.set_device("CPU")
|
||||
self.fc2 = nn.Dense(120, 84)
|
||||
self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
|
||||
self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
|
||||
self.fc2.matmul.set_device("CPU")
|
||||
self.fc2.bias_add.set_device("CPU")
|
||||
self.fc3 = nn.Dense(84, 10)
|
||||
self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
|
||||
self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
|
||||
self.fc3.matmul.set_device("CPU")
|
||||
self.fc3.bias_add.set_device("CPU")
|
||||
|
||||
def construct(self, input_x):
|
||||
output = self.conv1(input_x)
|
||||
|
|
|
@ -77,8 +77,8 @@ def test_heterogeneous_excutor():
|
|||
|
||||
# heterogeneous_excutor
|
||||
net_heter = Net(3, 10)
|
||||
net_heter.relu.relu.add_prim_attr("primitive_target", "CPU")
|
||||
net_heter.conv.conv2d.add_prim_attr("primitive_target", "CPU")
|
||||
net_heter.relu.relu.set_device("CPU")
|
||||
net_heter.conv.conv2d.set_device("CPU")
|
||||
|
||||
opt_heter = nn.Momentum(params=net_heter.trainable_params(),
|
||||
learning_rate=0.001, momentum=0.0009,
|
||||
|
|
|
@ -58,7 +58,7 @@ class AdamWeightDecayOp(Optimizer):
|
|||
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
|
||||
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
|
||||
self.opt = P.AdamWeightDecay()
|
||||
self.opt.add_prim_attr("primitive_target", "CPU")
|
||||
self.opt.set_device("CPU")
|
||||
|
||||
def construct(self, gradients):
|
||||
"""AdamWeightDecayOp"""
|
||||
|
|
|
@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
|||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.embedding = P.EmbeddingLookup().add_prim_attr("primitive_target", "CPU")
|
||||
self.embedding = P.EmbeddingLookup().set_device("CPU")
|
||||
|
||||
def construct(self, param, index, offset):
|
||||
return self.embedding(param, index, offset)
|
||||
|
|
|
@ -142,8 +142,8 @@ class NetFactory:
|
|||
context.set_ps_context(enable_ps=False)
|
||||
net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
|
||||
self.embedding_size, self.output_channels, self.target, self.sparse)
|
||||
net.conv.conv2d.add_prim_attr('primitive_target', 'CPU')
|
||||
net.conv.bias_add.add_prim_attr('primitive_target', 'CPU')
|
||||
net.conv.conv2d.set_device('CPU')
|
||||
net.conv.bias_add.set_device('CPU')
|
||||
net.set_train()
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))
|
||||
|
@ -159,8 +159,8 @@ class NetFactory:
|
|||
net = Menet(self.in_channels, self.out_channels, self.kernel_size, self.vocab_size,
|
||||
self.embedding_size, self.output_channels, self.target, self.sparse)
|
||||
net.embedding_lookup.set_param_ps()
|
||||
net.conv.conv2d.add_prim_attr('primitive_target', 'CPU')
|
||||
net.conv.bias_add.add_prim_attr('primitive_target', 'CPU')
|
||||
net.conv.conv2d.set_device('CPU')
|
||||
net.conv.bias_add.set_device('CPU')
|
||||
net.set_train()
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
opt = Adam(params=filter(lambda x: x.requires_grad, net.get_parameters()))
|
||||
|
|
|
@ -49,7 +49,7 @@ def test_heterogeneous_default_ascend_prim_cpu():
|
|||
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
output_device = net(inp1, inp2)
|
||||
net.relu1.add_prim_attr("primitive_target", "CPU")
|
||||
net.relu1.set_device("CPU")
|
||||
output_heter = net(inp1, inp2)
|
||||
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
|
||||
|
||||
|
@ -67,7 +67,7 @@ def test_heterogeneous_default_cpu_prim_ascend():
|
|||
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
output_device = net(inp1, inp2)
|
||||
net.relu1.add_prim_attr("primitive_target", "Ascend")
|
||||
net.relu1.set_device("Ascend")
|
||||
output_heter = net(inp1, inp2)
|
||||
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
|
||||
|
||||
|
@ -85,7 +85,7 @@ def test_heterogeneous_default_gpu_prim_cpu():
|
|||
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
output_device = net(inp1, inp2)
|
||||
net.relu1.add_prim_attr("primitive_target", "CPU")
|
||||
net.relu1.set_device("CPU")
|
||||
output_heter = net(inp1, inp2)
|
||||
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
|
||||
|
||||
|
@ -103,6 +103,6 @@ def test_heterogeneous_default_cpu_prim_gpu():
|
|||
inp1 = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
inp2 = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
output_device = net(inp1, inp2)
|
||||
net.relu1.add_prim_attr("primitive_target", "GPU")
|
||||
net.relu1.set_device("GPU")
|
||||
output_heter = net(inp1, inp2)
|
||||
assert np.allclose(output_device.asnumpy(), output_heter.asnumpy(), 1e-6, 1e-6)
|
||||
|
|
|
@ -48,7 +48,7 @@ class Net(nn.Cell):
|
|||
super().__init__()
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.offset = offset
|
||||
self.elu = P.EmbeddingLookup().shard(strategy1).add_prim_attr("primitive_target", target)
|
||||
self.elu = P.EmbeddingLookup().shard(strategy1).set_device(target)
|
||||
self.mm = P.BatchMatMul().shard(strategy2)
|
||||
|
||||
def construct(self, x, y):
|
||||
|
|
|
@ -54,7 +54,7 @@ class Net(nn.Cell):
|
|||
super().__init__()
|
||||
if shape is None:
|
||||
shape = [64, 64]
|
||||
self.gatherv2 = P.Gather().shard(strategy1, gather_out_strategy).add_prim_attr("primitive_target", target)
|
||||
self.gatherv2 = P.Gather().shard(strategy1, gather_out_strategy).set_device(target)
|
||||
self.mul = P.Mul().shard(strategy2)
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.axis = axis
|
||||
|
|
|
@ -43,7 +43,7 @@ class Net(Cell):
|
|||
super().__init__()
|
||||
self.gatherv2 = P.EmbeddingLookup().shard(strategy1)
|
||||
self.gatherv2.add_prim_attr(split_string, split_tuple)
|
||||
self.gatherv2.add_prim_attr("primitive_target", "CPU")
|
||||
self.gatherv2.set_device("CPU")
|
||||
self.mul = P.Mul().shard(strategy2)
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul = P.MatMul().shard(strategy3)
|
||||
|
@ -68,7 +68,7 @@ _b = Tensor(np.ones([8, 8, 8]), dtype=ms.float32)
|
|||
|
||||
def compile_net(net):
|
||||
optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1)
|
||||
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||
optimizer.sparse_opt.set_device("CPU")
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
train_net.set_train()
|
||||
|
|
|
@ -57,7 +57,7 @@ class Net(nn.Cell):
|
|||
if shape is None:
|
||||
shape = [64, 64]
|
||||
self.gatherv2 = P.Gather().shard(
|
||||
strategy1, gather_out_strategy).add_prim_attr("primitive_target", target)
|
||||
strategy1, gather_out_strategy).set_device(target)
|
||||
self.mul = P.Mul().shard(strategy2)
|
||||
self.softmax = P.Softmax().shard(strategy3)
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
|
|
|
@ -56,7 +56,7 @@ class Net(nn.Cell):
|
|||
super().__init__()
|
||||
if shape is None:
|
||||
shape = [64, 64]
|
||||
self.gatherv2 = P.SparseGatherV2().shard(strategy1).add_prim_attr("primitive_target", target)
|
||||
self.gatherv2 = P.SparseGatherV2().shard(strategy1).set_device(target)
|
||||
self.mul = P.Mul().shard(strategy2)
|
||||
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||
self.axis = axis
|
||||
|
|
Loading…
Reference in New Issue