fix initiliazer

This commit is contained in:
huangdongrun 2020-06-04 21:36:54 +08:00
parent 2a2dd7d340
commit 9081041199
2 changed files with 21 additions and 7 deletions

View File

@ -338,12 +338,6 @@ def initializer(init, shape=None, dtype=mstype.float32):
"the variable shape {}.".format(list(init.shape()), shape))
return init
if isinstance(init, str):
init_obj = _INITIALIZER_ALIAS[init.lower()]()
if init_obj is None:
raise ValueError("The class corresponding to '{}' was not found.".format(init))
init = init_obj
if isinstance(shape, list):
shape = tuple(shape)
elif isinstance(shape, numbers.Number):
@ -354,6 +348,15 @@ def initializer(init, shape=None, dtype=mstype.float32):
raise ValueError("Error shape={}".format(shape))
if isinstance(init, Initializer):
init.shape = init.shape if init.shape is not None else shape
init.dtype = init.dtype if init.dtype is not None else dtype
return init
if isinstance(init, str):
init_obj = _INITIALIZER_ALIAS[init.lower()]()
if init_obj is None:
raise ValueError("The class corresponding to '{}' was not found.".format(init))
init = init_obj
init.shape = shape
init.dtype = dtype
return init

View File

@ -141,6 +141,17 @@ def test_init_abnormal():
with py.raises(TypeError):
init.initializer([''], [5, 4], ms.float32)
def test_initializer_reinit():
weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16)
assert weights.dtype == ms.float16
assert weights.shape == (10, 1, 10, 10)
weights = init.initializer(weights)
assert weights.dtype == ms.float16
assert weights.shape == (10, 1, 10, 10)
weights.shape = None
weights = init.initializer(weights, (10, 1))
assert weights.dtype == ms.float16
assert weights.shape == (10, 1)
def test_init_xavier_uniform():
""" test_init_xavier_uniform """