forked from mindspore-Ecosystem/mindspore
fix initiliazer
This commit is contained in:
parent
2a2dd7d340
commit
9081041199
|
@ -338,12 +338,6 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|||
"the variable shape {}.".format(list(init.shape()), shape))
|
||||
return init
|
||||
|
||||
if isinstance(init, str):
|
||||
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
||||
if init_obj is None:
|
||||
raise ValueError("The class corresponding to '{}' was not found.".format(init))
|
||||
init = init_obj
|
||||
|
||||
if isinstance(shape, list):
|
||||
shape = tuple(shape)
|
||||
elif isinstance(shape, numbers.Number):
|
||||
|
@ -354,6 +348,15 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|||
raise ValueError("Error shape={}".format(shape))
|
||||
|
||||
if isinstance(init, Initializer):
|
||||
init.shape = init.shape if init.shape is not None else shape
|
||||
init.dtype = init.dtype if init.dtype is not None else dtype
|
||||
return init
|
||||
|
||||
if isinstance(init, str):
|
||||
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
||||
if init_obj is None:
|
||||
raise ValueError("The class corresponding to '{}' was not found.".format(init))
|
||||
init = init_obj
|
||||
init.shape = shape
|
||||
init.dtype = dtype
|
||||
return init
|
||||
|
|
|
@ -141,7 +141,18 @@ def test_init_abnormal():
|
|||
with py.raises(TypeError):
|
||||
init.initializer([''], [5, 4], ms.float32)
|
||||
|
||||
|
||||
def test_initializer_reinit():
|
||||
weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16)
|
||||
assert weights.dtype == ms.float16
|
||||
assert weights.shape == (10, 1, 10, 10)
|
||||
weights = init.initializer(weights)
|
||||
assert weights.dtype == ms.float16
|
||||
assert weights.shape == (10, 1, 10, 10)
|
||||
weights.shape = None
|
||||
weights = init.initializer(weights, (10, 1))
|
||||
assert weights.dtype == ms.float16
|
||||
assert weights.shape == (10, 1)
|
||||
|
||||
def test_init_xavier_uniform():
|
||||
""" test_init_xavier_uniform """
|
||||
gain = 1.2
|
||||
|
|
Loading…
Reference in New Issue