!7607 modify_normal_seed

Merge pull request !7607 from lilei/modify_normal

Commit d50736df2c
@@ -19,7 +19,7 @@ import math
 from functools import reduce
 import numpy as np
 from scipy.stats import truncnorm
-from .seed import _get_graph_seed
+from .seed import get_seed, _get_graph_seed
 from . import dtype as mstype
 from .tensor import Tensor, MetaTensor
 from .._c_expression import random_normal

@@ -44,7 +44,9 @@ class Initializer:

     @property
     def seed(self):
-        seed_ = self._seed if self._seed is not None else 1
+        seed_ = self._seed if self._seed is not None else get_seed()
+        if seed_ is None:
+            seed_ = 1
         _, seed = _get_graph_seed(seed_, "init")
         return seed

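The behavioral change in this hunk: an Initializer's seed no longer defaults straight to 1; it first falls back to the global seed registered via set_seed(). A standalone sketch of the fallback chain (resolve_init_seed is a hypothetical helper for illustration, not MindSpore code):

    def resolve_init_seed(explicit_seed, global_seed):
        # Mirrors Initializer.seed above: the per-initializer seed wins,
        # then the global seed from get_seed(), then the legacy default 1.
        seed_ = explicit_seed if explicit_seed is not None else global_seed
        if seed_ is None:
            seed_ = 1
        return seed_

    assert resolve_init_seed(5, 0) == 5        # explicit seed wins
    assert resolve_init_seed(None, 7) == 7     # falls back to set_seed() value
    assert resolve_init_seed(None, None) == 1  # old default kept as last resort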

@@ -410,7 +412,7 @@ def initializer(init, shape=None, dtype=mstype.float32):
        dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.

    Returns:
-        Union[Tensor, Initializer], When `init` is Tensor, the return is Tensor object,
+        Union[Tensor, MetaTensor], When `init` is Tensor, the return is Tensor object,
        otherwise the return is Initialize object.

    Examples:
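The docstring fix matches what the function actually hands back. A hedged usage sketch (shapes and dtype chosen purely for illustration): a Tensor passed as `init` comes back as a Tensor, while a string name yields a lazily-materialized object.

    import numpy as np
    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.common.initializer import initializer

    t = initializer(Tensor(np.ones((2, 2), np.float32)), shape=(2, 2))    # Tensor in, Tensor out
    m = initializer('XavierUniform', shape=(2, 2), dtype=mstype.float32)  # lazy MetaTensor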

@@ -16,19 +16,9 @@ import numpy as np

 import mindspore.nn as nn
 from mindspore import Tensor
-from mindspore.common import dtype as mstype
-from mindspore.common.initializer import initializer
 from mindspore.ops import operations as P


-def weight_variable(shape):
-    return initializer('XavierUniform', shape=shape, dtype=mstype.float32)
-
-
-def weight_variable_uniform(shape):
-    return initializer('Uniform', shape=shape, dtype=mstype.float32)
-
-
 def weight_variable_0(shape):
     zeros = np.zeros(shape).astype(np.float32)
     return Tensor(zeros)

@@ -41,26 +31,23 @@ def weight_variable_1(shape):

 def conv3x3(in_channels, out_channels, stride=1, padding=0):
     """3x3 convolution"""
-    weight_shape = (out_channels, in_channels, 3, 3)
-    weight = weight_variable(weight_shape)
     return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+                     kernel_size=3, stride=stride, padding=padding, weight_init='XavierUniform',
+                     has_bias=False, pad_mode="same")


 def conv1x1(in_channels, out_channels, stride=1, padding=0):
     """1x1 convolution"""
-    weight_shape = (out_channels, in_channels, 1, 1)
-    weight = weight_variable(weight_shape)
     return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+                     kernel_size=1, stride=stride, padding=padding, weight_init='XavierUniform',
+                     has_bias=False, pad_mode="same")


 def conv7x7(in_channels, out_channels, stride=1, padding=0):
     """7x7 convolution"""
-    weight_shape = (out_channels, in_channels, 7, 7)
-    weight = weight_variable(weight_shape)
     return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
+                     kernel_size=7, stride=stride, padding=padding, weight_init='Uniform',
+                     has_bias=False, pad_mode="same")


 def bn_with_initialize(out_channels):
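The pattern in these conv helpers: instead of materializing a weight Tensor up front via the removed weight_variable helpers (which drew from NumPy's global RNG), each layer now receives a named initializer and draws its values through the seeded machinery patched in the first file. A hedged sketch of why that helps reproducibility (channel sizes are arbitrary placeholders):

    import mindspore.nn as nn
    from mindspore.common import set_seed

    set_seed(0)
    conv_a = nn.Conv2d(3, 64, kernel_size=3, weight_init='XavierUniform',
                       has_bias=False, pad_mode="same")
    set_seed(0)
    conv_b = nn.Conv2d(3, 64, kernel_size=3, weight_init='XavierUniform',
                       has_bias=False, pad_mode="same")
    # Same global seed, same named initializer -> identical initial weights.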

@@ -68,8 +55,7 @@ def bn_with_initialize(out_channels):
     mean = weight_variable_0(shape)
     var = weight_variable_1(shape)
     beta = weight_variable_0(shape)
-    gamma = weight_variable_uniform(shape)
-    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
+    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                         beta_init=beta, moving_mean_init=mean, moving_var_init=var)
     return bn


@@ -79,18 +65,13 @@ def bn_with_initialize_last(out_channels):
     mean = weight_variable_0(shape)
     var = weight_variable_1(shape)
     beta = weight_variable_0(shape)
-    gamma = weight_variable_uniform(shape)
-    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
+    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init='Uniform',
                         beta_init=beta, moving_mean_init=mean, moving_var_init=var)
     return bn


 def fc_with_initialize(input_channels, out_channels):
-    weight_shape = (out_channels, input_channels)
-    weight = weight_variable(weight_shape)
-    bias_shape = (out_channels)
-    bias = weight_variable_uniform(bias_shape)
-    return nn.Dense(input_channels, out_channels, weight, bias)
+    return nn.Dense(input_channels, out_channels, weight_init='XavierUniform', bias_init='Uniform')


 class ResidualBlock(nn.Cell):
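The Dense and BatchNorm helpers get the same treatment; only the deterministic statistics buffers (mean, var, beta, built from zeros and ones) are still passed as concrete Tensors. A minimal sketch of the resulting calls, using the same arguments as this diff (channel counts are placeholders):

    import mindspore.nn as nn

    fc = nn.Dense(2048, 1000, weight_init='XavierUniform', bias_init='Uniform')
    bn = nn.BatchNorm2d(64, momentum=0.99, eps=0.00001, gamma_init='Uniform')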

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import numpy as np
-from numpy import allclose
 from mindspore.common import set_seed
 import mindspore.common.initializer as init

@@ -54,10 +53,10 @@ def test_using_same_seed_for_initializer():


 def test_using_diffserent_seed_for_initializer():
-    np.random.seed(0)
+    set_seed(0)
     net1 = ParameterNet()
     net1.init_parameters_data()
-    np.random.seed(1)
+    set_seed(1)
     net2 = ParameterNet()
     net2.init_parameters_data()
     for key in net1.parameters_dict():
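The diff ends inside this loop, so the assertion body is not shown. Purely as an illustration of the kind of check such a test can make, not the file's actual code (params_differ is a hypothetical helper):

    import numpy as np

    def params_differ(net1, net2):
        # Compare every parameter tensor of two networks elementwise.
        for key in net1.parameters_dict():
            a = net1.parameters_dict()[key].data.asnumpy()
            b = net2.parameters_dict()[key].data.asnumpy()
            if not np.allclose(a, b):
                return True
        return False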