forked from mindspore-Ecosystem/mindspore
!33925 parallel_initializer_seed
Merge pull request !33925 from yao_yf/auto_parallel_initializer_seed
Commit: 291b5faa98
@@ -9,6 +9,7 @@ mindspore.set_seed

 - The global seed is used by numpy.random, mindspore.common.Initializer, mindspore.ops.composite.random_ops, and mindspore.nn.probability.distribution.
 - If the global seed is not set, these packages each use their own seed: numpy.random and mindspore.common.Initializer pick a random seed value, while mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution use zero as the seed value.
 - A seed set by numpy.random.seed() is only used by numpy.random, whereas the seed set by this API is also used by numpy.random, so it is recommended to set all seeds through this API.
+- In semi_auto_parallel/auto_parallel mode, when set_seed is used, weights with the same shape and the same sharding strategy are initialized to the same result; otherwise they are initialized to different results.

 **Parameters:**
@@ -49,6 +49,10 @@ def set_seed(seed):

     The seed set by numpy.random.seed() is only used by numpy.random, while the seed set by this API will also be
     used by numpy.random, so setting all seeds through this API is recommended.

+    In semi_auto_parallel/auto_parallel mode, when using set_seed, weights with the same shape and the same sharding
+    strategy will be initialized to the same result; otherwise, they will be initialized to different results.
+
     Args:
         seed (int): The seed to be set.
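For context, a minimal sketch of the documented behavior. This assumes a working MindSpore install; the initializer name and shapes are chosen only for illustration, and no equality is asserted because the same-result guarantee applies to sliced weights under semi_auto_parallel/auto_parallel, not to standalone mode:

    import mindspore as ms
    from mindspore.common import set_seed
    from mindspore.common.initializer import initializer

    set_seed(1)  # fixes the global seed used by numpy.random and Initializer
    w1 = initializer("Uniform", [32, 32], ms.float32)
    w2 = initializer("Uniform", [32, 32], ms.float32)
    # Under semi_auto_parallel/auto_parallel, parameters built from w1 and w2
    # (same shape, same sharding strategy) initialize to the same sliced
    # values; without set_seed their slices would differ.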
@@ -110,6 +110,7 @@ class Tensor(Tensor_):
         >>> print(t4.dtype)
         Float32
     """
+    delta_seed = 0

     def __init__(self, input_data=None, dtype=None, shape=None, init=None, internal=False):
         self.init_finished = False
@@ -2184,13 +2185,23 @@ class Tensor(Tensor_):
                 from .seed import get_seed
                 global_seed = get_seed()
                 self._np_seed = np.random.get_state()[1][0]
-                self.need_set_seed = ((slice_index is not None) and (global_seed is None))
+                self.need_set_seed = (slice_index is not None)
+                self._global_seed = global_seed
+                self._device_num = 1
+                if self.need_set_seed:
+                    self._device_num = get_group_size()

             def __enter__(self):
                 if self.need_set_seed:
                     self.seed = self.init.seed
-                    np.random.seed(slice_index)
-                    self.init.seed = slice_index
+                    if self._global_seed is not None:
+                        np.random.seed(slice_index + self._global_seed)
+                        self.init.seed = slice_index + self._global_seed
+                    else:
+                        np.random.seed(slice_index + Tensor.delta_seed)
+                        self.init.seed = slice_index + Tensor.delta_seed
+                        Tensor.delta_seed += self._device_num

             def __exit__(self, ptype, value, trace):
                 if self.need_set_seed:
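The scheme above derives each slice's seed from slice_index plus either the global seed or a running delta. A standalone sketch of that arithmetic (plain Python; slice_index, global_seed, and device_num stand in for the values the real context manager receives):

    delta_seed = 0  # mirrors Tensor.delta_seed

    def slice_init_seed(slice_index, global_seed, device_num):
        # With a global seed (set_seed was called), every weight slice reuses
        # slice_index + global_seed, so same-shape/same-strategy weights match.
        # Without one, delta_seed advances by device_num per parameter, so
        # successive parameters draw from different seeds.
        global delta_seed
        if global_seed is not None:
            return slice_index + global_seed
        seed = slice_index + delta_seed
        delta_seed += device_num
        return seed

    assert slice_init_seed(3, 1, 8) == slice_init_seed(3, 1, 8) == 4  # stable with set_seed(1)
    assert slice_init_seed(3, None, 8) == 3    # first parameter, slice 3
    assert slice_init_seed(3, None, 8) == 11   # next parameter gets a new seed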
@@ -381,6 +381,6 @@ def test_train_feed(num_classes=65536):
     model = Model(net, loss_fn=loss, optimizer=opt)
     model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
     loss_value = np.array(parallel_callback.loss_list)
-    expect_out = [11.105435, 10.954063, 10.460592]
+    expect_out = [11.087254, 10.876551, 10.045684]
     print(loss_value)
     assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
@@ -57,6 +57,11 @@ class Hccl():

 # pylint: disable=unused-argument
 def get_rank_id(group=None):
     hccl = Hccl()
+    if group is not None:
+        group_size = get_rank_size(group)
+        rank = hccl.rank_id
+        if rank >= group_size:
+            return rank % group_size
     return hccl.rank_id
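The group branch above folds a global rank into its group by modulo. A tiny sketch of that mapping (rank_in_group is a hypothetical helper, for illustration only):

    def rank_in_group(global_rank, group_size):
        # Mirrors get_rank_id: ranks beyond the group size wrap around.
        if global_rank >= group_size:
            return global_rank % group_size
        return global_rank

    assert rank_in_group(5, 4) == 1  # global rank 5 lands on rank 1 of a 4-device group
    assert rank_in_group(2, 4) == 2  # in-range ranks pass through unchanged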
@@ -25,25 +25,31 @@ from mindspore.common import set_seed
 from hccl_test.manage.api import Hccl


 class Net(nn.Cell):
-    def __init__(self, strategy1, strategy2, weight):
+    def __init__(self, strategy1, strategy2, weight1, weight2):
         super().__init__()
-        self.weight = Parameter(weight, "w1")
-        self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
+        self.weight1 = Parameter(weight1, "w1")
+        self.weight2 = Parameter(weight2, "w2")
+        self.matmul1 = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
+        self.matmul2 = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
         self.relu = P.ReLU().shard(strategy2)

     def construct(self, x):
-        out = self.matmul(x, self.weight)
+        out = self.matmul1(x, self.weight1)
+        out = self.matmul2(out, self.weight2)
         out = self.relu(out)
         return out


-def check_initializer_weight_slice(init_name="Uniform"):
+def check_initializer_weight_slice(init_name="Uniform", using_seed=False):
     def get_slice(rank):
+        if using_seed:
+            set_seed(1)
         hccl = Hccl()
         rank_save = hccl.rank_id
         hccl.rank_id = rank
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(dataset_strategy="full_batch")
-        context.set_auto_parallel_context(device_num=8, global_rank=0)
+        context.set_auto_parallel_context(device_num=8, global_rank=rank)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
         strategy1 = ((2, 1), (4, 1))
         strategy2 = ((2, 4),)
@@ -51,40 +57,65 @@ def check_initializer_weight_slice(init_name="Uniform"):
         exe = me._cell_graph_executor

         x = Tensor(np.ones([32, 32]), dtype=ms.float32)
-        weight = initializer(init_name, [64, 32], ms.float32)
-        net = Net(strategy1, strategy2, weight)
+        weight1 = initializer(init_name, [32, 32], ms.float32)
+        weight2 = initializer(init_name, [32, 32], ms.float32)
+        net = Net(strategy1, strategy2, weight1, weight2)
         net.set_auto_parallel()
         net.set_train()
         exe.compile(net, x, auto_parallel_mode=True, phase='train')
         hccl.rank_id = rank_save
-        return net.parameters_dict()['w1'].data.asnumpy()
+        return net.parameters_dict()['w1'].data.asnumpy(), net.parameters_dict()['w2'].data.asnumpy()

-    slice0 = get_slice(0)
-    slice1 = get_slice(1)
-    slice4 = get_slice(4)
-    slice_shape = slice0.shape
+    Tensor.delta_seed = 0
+    w1_slice0, w2_slice0 = get_slice(0)
+    Tensor.delta_seed = 0
+    w1_slice1, _ = get_slice(1)
+    Tensor.delta_seed = 0
+    w1_slice4, _ = get_slice(4)
+    slice_shape = w1_slice0.shape

-    slice0 = slice0.flatten()
-    slice1 = slice1.flatten()
-    slice4 = slice4.flatten()
-    expect_slice_shape = (16, 32)
+    w1_slice0 = w1_slice0.flatten()
+    w1_slice1 = w1_slice1.flatten()
+    w1_slice4 = w1_slice4.flatten()
+    w2_slice0 = w2_slice0.flatten()
+    expect_slice_shape = (8, 32)

     assert expect_slice_shape == slice_shape
-    assert all(slice0 == slice4)
+    assert all(w1_slice0 == w1_slice4)
     if init_name not in ["One", "Zero"]:
-        assert any(slice0 != slice1)
+        assert any(w1_slice0 != w1_slice1)
+    if using_seed:
+        assert all(w1_slice0 == w2_slice0)
+    else:
+        assert any(w1_slice0 != w2_slice0)


 initializers = ["Uniform", "Normal", "TruncatedNormal", "HeUniform", "HeNormal", "XavierUniform", "One", "Zero"]


 def test_initializer_weight_slice():
+    """
+    Feature: test initializer in auto parallel with/without using set_seed.
+    Description: test initializer in auto parallel with/without using set_seed.
+    Expectation: no assert error raised.
+    """
     for init_name in initializers:
         check_initializer_weight_slice(init_name)
+    for init_name in initializers:
+        check_initializer_weight_slice(init_name, using_seed=True)


 def test_wrong_order_set_parallel_mode_with_initializer():
-    weight = initializer("Normal", [64, 32], ms.float32)
+    """
+    Feature: test parameter initialize in auto parallel.
+    Description: test parameter initialize in auto parallel applying initializer before setting auto parallel mode.
+    Expectation: no assert error raised.
+    """
+    weight1 = initializer("Normal", [32, 32], ms.float32)
+    weight2 = initializer("Normal", [32, 32], ms.float32)
     strategy1 = ((2, 1), (4, 1))
     strategy2 = ((2, 4),)
-    net = Net(strategy1, strategy2, weight)
+    net = Net(strategy1, strategy2, weight1, weight2)
     exe = me._cell_graph_executor
     x = Tensor(np.ones([32, 32]), dtype=ms.float32)
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
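What the new assertions encode, as a standalone numpy sketch: with sharding strategy ((4, 1)) a [32, 32] weight is split into (8, 32) slices, and, assuming rank r of 8 holds slice index r % 4 (an inference from the rank0/rank4 check above, not stated in the diff), a fixed global seed makes slices reproducible across ranks that hold the same slice:

    import numpy as np

    def slice_values(rank, global_seed):
        slice_index = rank % 4  # assumed mapping: ranks 0 and 4 hold the same slice
        np.random.seed(slice_index + global_seed)
        return np.random.uniform(size=(8, 32))  # matches expect_slice_shape

    assert (slice_values(0, 1) == slice_values(4, 1)).all()  # same slice index -> equal
    assert (slice_values(0, 1) != slice_values(1, 1)).any()  # different slices differ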
@@ -92,67 +123,39 @@ def test_wrong_order_set_parallel_mode_with_initializer():
     with pytest.raises(RuntimeError):
         exe.compile(net, x, auto_parallel_mode=True, phase='train')


 def test_wrong_order_set_same_parallel_mode_with_initializer():
+    """
+    Feature: test parameter initialize in auto parallel.
+    Description: test parameter initialize in auto parallel applying initializer after setting auto parallel mode.
+    Expectation: no assert error raised.
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    weight = initializer("Normal", [64, 32], ms.float32)
+    weight1 = initializer("Normal", [32, 32], ms.float32)
+    weight2 = initializer("Normal", [32, 32], ms.float32)
     strategy1 = ((2, 1), (4, 1))
     strategy2 = ((2, 4),)
-    net = Net(strategy1, strategy2, weight)
+    net = Net(strategy1, strategy2, weight1, weight2)
     exe = me._cell_graph_executor
     x = Tensor(np.ones([32, 32]), dtype=ms.float32)
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
     net.set_auto_parallel()
     exe.compile(net, x, auto_parallel_mode=True, phase='train')


 def test_wrong_order_set_parallel_mode_without_initializer():
-    weight = Tensor(np.ones([64, 32]), ms.float32)
+    """
+    Feature: test parameter initialize in auto parallel.
+    Description: test parameter initialize in auto parallel not using initializer.
+    Expectation: no assert error raised.
+    """
+    weight1 = Tensor(np.ones([32, 32]), ms.float32)
+    weight2 = Tensor(np.ones([32, 32]), ms.float32)
     strategy1 = ((2, 1), (4, 1))
     strategy2 = ((2, 4),)
-    net = Net(strategy1, strategy2, weight)
+    net = Net(strategy1, strategy2, weight1, weight2)
     exe = me._cell_graph_executor
     x = Tensor(np.ones([32, 32]), dtype=ms.float32)
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     net.set_auto_parallel()
     exe.compile(net, x, auto_parallel_mode=True, phase='train')

-
-def test_check_initializer_weight_slice_seed(init_name="Uniform"):
-    def get_slice(rank):
-        set_seed(1)
-        hccl = Hccl()
-        rank_save = hccl.rank_id
-        hccl.rank_id = rank
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(dataset_strategy="full_batch")
-        context.set_auto_parallel_context(device_num=8, global_rank=0)
-        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-        strategy1 = ((2, 1), (4, 1))
-        strategy2 = ((2, 4),)
-        context.set_context(mode=context.GRAPH_MODE)
-        exe = me._cell_graph_executor
-
-        x = Tensor(np.ones([32, 32]), dtype=ms.float32)
-        weight = initializer(init_name, [64, 32], ms.float32)
-        net = Net(strategy1, strategy2, weight)
-        net.set_auto_parallel()
-        net.set_train()
-        exe.compile(net, x, auto_parallel_mode=True, phase='train')
-        hccl.rank_id = rank_save
-        return net.parameters_dict()['w1'].data.asnumpy()
-
-    slice0 = get_slice(0)
-    slice1 = get_slice(1)
-    slice4 = get_slice(4)
-    slice_shape = slice0.shape
-
-    slice0 = slice0.flatten()
-    slice1 = slice1.flatten()
-    slice4 = slice4.flatten()
-    expect_slice_shape = (16, 32)
-
-    assert expect_slice_shape == slice_shape
-    assert all(slice0 == slice4)
-    assert all(slice0 == slice1)

 if __name__ == '__main__':
     test_initializer_weight_slice()