forked from mindspore-Ecosystem/mindspore
!1853 Fix initializer
Merge pull request !1853 from amongo/FixInitializer
This commit is contained in:
commit
bd34c6ec8b
|
@ -338,12 +338,6 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
||||||
"the variable shape {}.".format(list(init.shape()), shape))
|
"the variable shape {}.".format(list(init.shape()), shape))
|
||||||
return init
|
return init
|
||||||
|
|
||||||
if isinstance(init, str):
|
|
||||||
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
|
||||||
if init_obj is None:
|
|
||||||
raise ValueError("The class corresponding to '{}' was not found.".format(init))
|
|
||||||
init = init_obj
|
|
||||||
|
|
||||||
if isinstance(shape, list):
|
if isinstance(shape, list):
|
||||||
shape = tuple(shape)
|
shape = tuple(shape)
|
||||||
elif isinstance(shape, numbers.Number):
|
elif isinstance(shape, numbers.Number):
|
||||||
|
@ -354,6 +348,15 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
||||||
raise ValueError("Error shape={}".format(shape))
|
raise ValueError("Error shape={}".format(shape))
|
||||||
|
|
||||||
if isinstance(init, Initializer):
|
if isinstance(init, Initializer):
|
||||||
|
init.shape = init.shape if init.shape is not None else shape
|
||||||
|
init.dtype = init.dtype if init.dtype is not None else dtype
|
||||||
|
return init
|
||||||
|
|
||||||
|
if isinstance(init, str):
|
||||||
|
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
||||||
|
if init_obj is None:
|
||||||
|
raise ValueError("The class corresponding to '{}' was not found.".format(init))
|
||||||
|
init = init_obj
|
||||||
init.shape = shape
|
init.shape = shape
|
||||||
init.dtype = dtype
|
init.dtype = dtype
|
||||||
return init
|
return init
|
||||||
|
|
|
@ -141,7 +141,18 @@ def test_init_abnormal():
|
||||||
with py.raises(TypeError):
|
with py.raises(TypeError):
|
||||||
init.initializer([''], [5, 4], ms.float32)
|
init.initializer([''], [5, 4], ms.float32)
|
||||||
|
|
||||||
|
def test_initializer_reinit():
|
||||||
|
weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16)
|
||||||
|
assert weights.dtype == ms.float16
|
||||||
|
assert weights.shape == (10, 1, 10, 10)
|
||||||
|
weights = init.initializer(weights)
|
||||||
|
assert weights.dtype == ms.float16
|
||||||
|
assert weights.shape == (10, 1, 10, 10)
|
||||||
|
weights.shape = None
|
||||||
|
weights = init.initializer(weights, (10, 1))
|
||||||
|
assert weights.dtype == ms.float16
|
||||||
|
assert weights.shape == (10, 1)
|
||||||
|
|
||||||
def test_init_xavier_uniform():
|
def test_init_xavier_uniform():
|
||||||
""" test_init_xavier_uniform """
|
""" test_init_xavier_uniform """
|
||||||
gain = 1.2
|
gain = 1.2
|
||||||
|
|
Loading…
Reference in New Issue