!14634 fix mindspore.common.initializer.Normal

From: @wudenggang
Reviewed-by: @kingxian,@zhoufeng54
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2021-04-13 11:03:02 +08:00 committed by Gitee
commit a83c3dbce6
1 changed files with 6 additions and 4 deletions

View File

@ -366,25 +366,27 @@ class Uniform(Initializer):
@_register()
class Normal(Initializer):
"""
Initialize a normal array, and obtain values N(0, sigma) from the uniform distribution
Initialize a normal array, and obtain values N(sigma, mean) from the normal distribution
to fill the input tensor.
Args:
sigma (float): The sigma of the array. Default: 0.01.
mean (float): The mean of the array. Default: 0.0.
Returns:
Array, normal array.
"""
def __init__(self, sigma=0.01):
super(Normal, self).__init__(sigma=sigma)
def __init__(self, sigma=0.01, mean=0.0):
super(Normal, self).__init__(sigma=sigma, mean=mean)
self.sigma = sigma
self.mean = mean
def _initialize(self, arr):
seed, seed2 = self.seed
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor)
output_data = output_tensor.asnumpy()
output_data *= self.sigma
output_data = output_data * self.sigma + self.mean
_assignment(arr, output_data)
@_register()