forked from mindspore-Ecosystem/mindspore
!4326 add kaiming normal init
Merge pull request !4326 from baihuawei/kaiming
This commit is contained in:
commit
3e215aaac2
|
@ -151,6 +151,84 @@ class One(Initializer):
|
|||
_assignment(arr, 1)
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(shape):
|
||||
"""
|
||||
calculate fan_in and fan_out
|
||||
|
||||
Args:
|
||||
shape (tuple): input shape.
|
||||
|
||||
Returns:
|
||||
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
|
||||
"""
|
||||
dimensions = len(shape)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
||||
if dimensions == 2: # Linear
|
||||
fan_in = shape[1]
|
||||
fan_out = shape[0]
|
||||
else:
|
||||
num_input_fmaps = shape[1]
|
||||
num_output_fmaps = shape[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = shape[2] * shape[3]
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def _calculate_correct_fan(shape, mode):
|
||||
"""
|
||||
Calculate fan.
|
||||
|
||||
Args:
|
||||
shape (tuple): input shape.
|
||||
mode (str): only support fan_in and fan_out.
|
||||
|
||||
Returns:
|
||||
fan_in or fan_out.
|
||||
"""
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(shape)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def _calculate_gain(nonlinearity, param=None):
|
||||
"""
|
||||
Calculate gain.
|
||||
|
||||
Args:
|
||||
nonlinearity (str): nonlinearity function.
|
||||
param (str): used to calculate negative_slope.
|
||||
|
||||
Returns:
|
||||
number.
|
||||
"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
res = 1
|
||||
elif nonlinearity == 'tanh':
|
||||
res = 5.0 / 3
|
||||
elif nonlinearity == 'relu':
|
||||
res = math.sqrt(2.0)
|
||||
elif nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||
# True/False are instances of int, hence check above
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
return res
|
||||
|
||||
|
||||
def _calculate_in_and_out(arr):
|
||||
"""
|
||||
Calculate n_in and n_out.
|
||||
|
@ -223,6 +301,35 @@ class HeUniform(Initializer):
|
|||
_assignment(arr, data)
|
||||
|
||||
|
||||
@_register('he_normal')
|
||||
class HeNormal(Initializer):
|
||||
r"""
|
||||
Initialize the array with He kaiming Normal algorithm, and from a normal distribution collect samples within
|
||||
N(0, sigma).
|
||||
|
||||
Args:
|
||||
negative_slope (int, float, bool): Default: 0, used when nonlinearity is 'leaky_relu'.
|
||||
mode (str): Default: fan_in.
|
||||
nonlinearity (str): Default: leaky_relu.
|
||||
|
||||
Returns:
|
||||
Array, assigned array.
|
||||
"""
|
||||
def __init__(self, negative_slope=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
super(HeNormal, self).__init__(negative_slope=negative_slope, mode=mode, nonlinearity=nonlinearity)
|
||||
self.negative_slope = negative_slope
|
||||
self.mode = mode
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def _initialize(self, arr):
|
||||
fan = _calculate_correct_fan(arr.shape, self.mode)
|
||||
gain = _calculate_gain(self.nonlinearity, self.negative_slope)
|
||||
std = gain / math.sqrt(fan)
|
||||
data = np.random.normal(0, std, arr.shape)
|
||||
|
||||
_assignment(arr, data)
|
||||
|
||||
|
||||
class Constant(Initializer):
|
||||
"""
|
||||
Initialize a constant.
|
||||
|
@ -372,6 +479,7 @@ __all__ = [
|
|||
'Normal',
|
||||
'Uniform',
|
||||
'HeUniform',
|
||||
'HeNormal',
|
||||
'XavierUniform',
|
||||
'One',
|
||||
'Zero',
|
||||
|
|
Loading…
Reference in New Issue