forked from mindspore-Ecosystem/mindspore
!4643 add test cases for parameter slice init using all kind of initializers
Merge pull request !4643 from yihuaijie/dev
This commit is contained in:
commit
025e6f23f7
|
@ -23,7 +23,7 @@ from mindspore.common.initializer import initializer
|
||||||
from hccl_test.manage.api import Hccl
|
from hccl_test.manage.api import Hccl
|
||||||
|
|
||||||
|
|
||||||
def test_initializer_weight_slice():
|
def check_initializer_weight_slice(init_name="Uniform"):
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self, strategy1, strategy2, weight):
|
def __init__(self, strategy1, strategy2, weight):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -49,7 +49,7 @@ def test_initializer_weight_slice():
|
||||||
exe = me._executor
|
exe = me._executor
|
||||||
|
|
||||||
x = Tensor(np.ones([32, 32]), dtype=ms.float32)
|
x = Tensor(np.ones([32, 32]), dtype=ms.float32)
|
||||||
weight = initializer("Uniform", [64, 32], ms.float32)
|
weight = initializer(init_name, [64, 32], ms.float32)
|
||||||
net = Net(strategy1, strategy2, weight)
|
net = Net(strategy1, strategy2, weight)
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
exe.compile(net, x, auto_parallel_mode=True, phase='train')
|
exe.compile(net, x, auto_parallel_mode=True, phase='train')
|
||||||
|
@ -68,8 +68,14 @@ def test_initializer_weight_slice():
|
||||||
|
|
||||||
assert expect_slice_shape == slice_shape
|
assert expect_slice_shape == slice_shape
|
||||||
assert all(slice0 == slice4)
|
assert all(slice0 == slice4)
|
||||||
assert any(slice0 != slice1)
|
if init_name not in ["One", "Zero"]:
|
||||||
|
assert any(slice0 != slice1)
|
||||||
|
|
||||||
|
initializers = ["Uniform", "Normal", "TruncatedNormal", "HeUniform", "HeNormal", "XavierUniform", "One", "Zero"]
|
||||||
|
|
||||||
|
def test_initializer_weight_slice():
|
||||||
|
for init_name in initializers:
|
||||||
|
check_initializer_weight_slice(init_name)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_initializer_weight_slice()
|
test_initializer_weight_slice()
|
||||||
|
|
Loading…
Reference in New Issue