forked from mindspore-Ecosystem/mindspore
!14634 fix mindspore.common.initializer.Normal
From: @wudenggang Reviewed-by: @kingxian,@zhoufeng54 Signed-off-by: @kingxian
This commit is contained in:
commit
a83c3dbce6
|
@ -366,25 +366,27 @@ class Uniform(Initializer):
|
|||
@_register()
|
||||
class Normal(Initializer):
|
||||
"""
|
||||
Initialize a normal array, and obtain values N(0, sigma) from the uniform distribution
|
||||
Initialize a normal array, and obtain values N(sigma, mean) from the normal distribution
|
||||
to fill the input tensor.
|
||||
|
||||
Args:
|
||||
sigma (float): The sigma of the array. Default: 0.01.
|
||||
mean (float): The mean of the array. Default: 0.0.
|
||||
|
||||
Returns:
|
||||
Array, normal array.
|
||||
"""
|
||||
def __init__(self, sigma=0.01):
|
||||
super(Normal, self).__init__(sigma=sigma)
|
||||
def __init__(self, sigma=0.01, mean=0.0):
|
||||
super(Normal, self).__init__(sigma=sigma, mean=mean)
|
||||
self.sigma = sigma
|
||||
self.mean = mean
|
||||
|
||||
def _initialize(self, arr):
|
||||
seed, seed2 = self.seed
|
||||
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
|
||||
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor)
|
||||
output_data = output_tensor.asnumpy()
|
||||
output_data *= self.sigma
|
||||
output_data = output_data * self.sigma + self.mean
|
||||
_assignment(arr, output_data)
|
||||
|
||||
@_register()
|
||||
|
|
Loading…
Reference in New Issue