Fixes #I3AB6T

This commit is contained in:
wudenggang 2021-04-04 11:26:16 +08:00
parent c99fe1e412
commit 696d80607d
1 changed files with 6 additions and 4 deletions

View File

@ -366,25 +366,27 @@ class Uniform(Initializer):
@_register()
class Normal(Initializer):
"""
Initialize a normal array, and obtain values N(0, sigma) from the uniform distribution
Initialize a normal array, and obtain values N(sigma, mean) from the normal distribution
to fill the input tensor.
Args:
sigma (float): The sigma of the array. Default: 0.01.
mean (float): The mean of the array. Default: 0.0.
Returns:
Array, normal array.
"""
def __init__(self, sigma=0.01):
super(Normal, self).__init__(sigma=sigma)
def __init__(self, sigma=0.01, mean=0.0):
super(Normal, self).__init__(sigma=sigma, mean=mean)
self.sigma = sigma
self.mean = mean
def _initialize(self, arr):
seed, seed2 = self.seed
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor)
output_data = output_tensor.asnumpy()
output_data *= self.sigma
output_data = output_data * self.sigma + self.mean
_assignment(arr, output_data)
@_register()