iterfaces change: _Constant to Constant
This commit is contained in:
parent
04eab166e5
commit
bb5320be13
|
@ -180,18 +180,18 @@ class HeUniform(Initializer):
|
|||
_assignment(arr, data)
|
||||
|
||||
|
||||
class _Constant(Initializer):
|
||||
class Constant(Initializer):
|
||||
"""
|
||||
Initialize a constant.
|
||||
|
||||
Args:
|
||||
value (int or numpy.ndarray): The value to initialize.
|
||||
value (Union[int, numpy.ndarray]): The value to initialize.
|
||||
|
||||
Returns:
|
||||
Array, initialize array.
|
||||
"""
|
||||
def __init__(self, value):
|
||||
super(_Constant, self).__init__(value=value)
|
||||
super(Constant, self).__init__(value=value)
|
||||
self.value = value
|
||||
|
||||
def _initialize(self, arr):
|
||||
|
@ -266,8 +266,16 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|||
|
||||
Args:
|
||||
init (Union[Tensor, str, Initializer, numbers.Number]): Initialize value.
|
||||
|
||||
- `str`: The `init` should be the alias of the class inheriting from `Initializer` and the corresponding
|
||||
class will be called.
|
||||
|
||||
- `Initializer`: The `init` should be the class inheriting from `Initializer` to initialize tensor.
|
||||
|
||||
- `numbers.Number`: The `Constant` will be called to initialize tensor.
|
||||
|
||||
shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
|
||||
output. Default: None.
|
||||
output. Default: None.
|
||||
dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mstype.float32.
|
||||
|
||||
Returns:
|
||||
|
@ -295,7 +303,7 @@ def initializer(init, shape=None, dtype=mstype.float32):
|
|||
raise ValueError(msg)
|
||||
|
||||
if isinstance(init, numbers.Number):
|
||||
init_obj = _Constant(init)
|
||||
init_obj = Constant(init)
|
||||
elif isinstance(init, str):
|
||||
init_obj = _INITIALIZER_ALIAS[init.lower()]()
|
||||
else:
|
||||
|
@ -314,4 +322,5 @@ __all__ = [
|
|||
'HeUniform',
|
||||
'XavierUniform',
|
||||
'One',
|
||||
'Zero']
|
||||
'Zero',
|
||||
'Constant']
|
||||
|
|
|
@ -37,8 +37,8 @@ def _check_value(tensor, value_min, value_max):
|
|||
for ele in nd.flatten():
|
||||
if value_min <= ele <= value_max:
|
||||
continue
|
||||
raise TypeError('value_min = %d, ele = %d, value_max = %d'
|
||||
% (value_min, ele, value_max))
|
||||
raise ValueError('value_min = %d, ele = %d, value_max = %d'
|
||||
% (value_min, ele, value_max))
|
||||
|
||||
|
||||
def _check_uniform(tensor, boundary_a, boundary_b):
|
||||
|
@ -92,6 +92,11 @@ def test_init_one_alias():
|
|||
_check_value(tensor, 1, 1)
|
||||
|
||||
|
||||
def test_init_constant():
|
||||
tensor = init.initializer(init.Constant(1), [2, 2], ms.float32)
|
||||
_check_value(tensor, 1, 1)
|
||||
|
||||
|
||||
def test_init_uniform():
|
||||
scale = 10
|
||||
tensor = init.initializer(init.Uniform(scale=scale), [5, 4], ms.float32)
|
||||
|
|
Loading…
Reference in New Issue