forked from mindspore-Ecosystem/mindspore
!14634 fix mindspore.common.initializer.Normal
From: @wudenggang Reviewed-by: @kingxian,@zhoufeng54 Signed-off-by: @kingxian
This commit is contained in:
commit
a83c3dbce6
|
@ -366,25 +366,27 @@ class Uniform(Initializer):
|
||||||
@_register()
|
@_register()
|
||||||
class Normal(Initializer):
|
class Normal(Initializer):
|
||||||
"""
|
"""
|
||||||
Initialize a normal array, and obtain values N(0, sigma) from the uniform distribution
|
Initialize a normal array, and obtain values N(sigma, mean) from the normal distribution
|
||||||
to fill the input tensor.
|
to fill the input tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sigma (float): The sigma of the array. Default: 0.01.
|
sigma (float): The sigma of the array. Default: 0.01.
|
||||||
|
mean (float): The mean of the array. Default: 0.0.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Array, normal array.
|
Array, normal array.
|
||||||
"""
|
"""
|
||||||
def __init__(self, sigma=0.01):
|
def __init__(self, sigma=0.01, mean=0.0):
|
||||||
super(Normal, self).__init__(sigma=sigma)
|
super(Normal, self).__init__(sigma=sigma, mean=mean)
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
|
self.mean = mean
|
||||||
|
|
||||||
def _initialize(self, arr):
|
def _initialize(self, arr):
|
||||||
seed, seed2 = self.seed
|
seed, seed2 = self.seed
|
||||||
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
|
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
|
||||||
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor)
|
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor)
|
||||||
output_data = output_tensor.asnumpy()
|
output_data = output_tensor.asnumpy()
|
||||||
output_data *= self.sigma
|
output_data = output_data * self.sigma + self.mean
|
||||||
_assignment(arr, output_data)
|
_assignment(arr, output_data)
|
||||||
|
|
||||||
@_register()
|
@_register()
|
||||||
|
|
Loading…
Reference in New Issue