Wei Luning 2020-08-04 10:51:06 +08:00
parent cdc5131869
commit e5fc529159
13 changed files with 82 additions and 60 deletions

View File

@@ -340,7 +340,8 @@ class Parameter(MetaTensor):
                 Default: False.

         Returns:
-            Parameter, Parameter after init data.
+            Parameter, the `Parameter` after init data. If the current `Parameter` was
+            already initialized, the same initialized `Parameter` is returned.
         """
         if self.init_mode is None:
             return self
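
In practice the new contract makes `init_data()` idempotent. A minimal sketch of the documented behaviour (assuming `net` is any cell with lazily-initialized parameters):

    p = net.trainable_params()[0]
    first = p.init_data()    # materializes the data on first call
    second = p.init_data()   # init_mode is now None, so the call returns self
    assert first is second   # the same initialized Parameter both times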

View File

@@ -536,6 +536,10 @@ class Cell:
        """
        Init all parameters' data and replace the original saved parameters in cell.

+        Notes:
+            trainable_params() and other similar interfaces may return different
+            parameter instances after `init_parameters_data`; do not cache their results.
+
        Args:
            auto_parallel_mode (bool): If running in auto_parallel_mode.
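
A minimal sketch of the caveat, assuming `net` is a Cell whose parameters have not been materialized yet:

    stale = net.trainable_params()   # references captured before initialization
    net.init_parameters_data()       # may replace the saved Parameter objects
    fresh = net.trainable_params()   # re-query; `stale` may hold old instances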

View File

@@ -425,12 +425,13 @@ class Optimizer(Cell):
            raise TypeError(f"The parameter only support 'Parameter' or 'list' type.")
        lr = []
+       ids = [id(p) for p in self.parameters]
        for p in param_list:
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
-           if p not in self.parameters:
+           if id(p) not in ids:
                raise ValueError(f"The parameter {p.name} is not in optimizer.")
            if self.is_group_lr:
-               index = self.parameters.index(p)
+               index = ids.index(id(p))
                lr.append(get_lr_value(self.learning_rate[index]))
            else:
                lr.append(get_lr_value(self.learning_rate))
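
The switch to `id()` lookups keeps membership tests away from `Parameter.__eq__`, which, like most tensor types, compares element-wise instead of returning a plain bool. A stand-in sketch with a hypothetical `Vec` class (not MindSpore's actual `Parameter`):

    class Vec:  # hypothetical tensor-like type with element-wise __eq__
        def __init__(self, data, name):
            self.data, self.name = data, name
        def __eq__(self, other):
            return [a == b for a, b in zip(self.data, other.data)]

    params = [Vec([1, 2], 'w'), Vec([3, 4], 'b')]
    probe = Vec([9, 9], 'x')
    print(probe in params)            # True! the element-wise result is a
                                      # non-empty list, which bool() treats as a match
    ids = [id(p) for p in params]
    print(id(probe) in ids)           # False, as intended: identity only
    print(ids.index(id(params[1])))   # 1 -- a stable index without __eq__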

View File

@@ -84,8 +84,14 @@ if __name__ == '__main__':
    lr = Tensor(lr)

    # optimizer
-   decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
-   no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
+   decayed_params = []
+   no_decayed_params = []
+   for param in net.trainable_params():
+       if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+           decayed_params.append(param)
+       else:
+           no_decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
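
The explicit loop replaces a `param not in decayed_params` membership test that would again route through `Parameter.__eq__`. The same single-pass split as a hypothetical helper (not part of the scripts):

    def split_decay_params(params, skip_keys=('beta', 'gamma', 'bias')):
        """Partition parameters by name in one pass, no membership tests."""
        decayed, no_decayed = [], []
        for p in params:
            if any(key in p.name for key in skip_keys):
                no_decayed.append(p)   # biases and norm scales: no weight decay
            else:
                decayed.append(p)
        return decayed, no_decayed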

View File

@@ -290,7 +290,6 @@ class MobileNetV3(nn.Cell):
                                   kernel_size=1, has_bias=True, pad_mode='pad')
        self.squeeze = P.Squeeze(axis=(2, 3))

-       self.init_parameters_data()
        self._initialize_weights()

    def construct(self, x):
@@ -320,6 +319,7 @@ class MobileNetV3(nn.Cell):
        Examples:
            >>> _initialize_weights()
        """
+       self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, (nn.Conv2d)):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
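
Moving the call keeps materialization next to its only consumer: construction stays cheap and lazy, and parameter data is created just before `_initialize_weights` overwrites it. A sketch of the resulting order of operations:

    def _initialize_weights(self):
        self.init_parameters_data()          # materialize lazy parameters first...
        for _, m in self.cells_and_names():  # ...then overwrite them in place
            if isinstance(m, nn.Conv2d):
                pass  # assign m.weight.default_input etc., as in the hunk above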

View File

@@ -101,12 +101,12 @@ if __name__ == '__main__':
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
-                                                               cell.weight.default_input.shape,
-                                                               cell.weight.default_input.dtype).to_tensor()
+                                                               cell.weight.shape,
+                                                               cell.weight.dtype)
        if isinstance(cell, nn.Dense):
            cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
-                                                               cell.weight.default_input.shape,
-                                                               cell.weight.default_input.dtype).to_tensor()
+                                                               cell.weight.shape,
+                                                               cell.weight.dtype)

    # init lr
    if args_opt.net == "resnet50":
@@ -123,8 +123,14 @@ if __name__ == '__main__':
    lr = Tensor(lr)

    # define opt
-   decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params()))
-   no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
+   decayed_params = []
+   no_decayed_params = []
+   for param in net.trainable_params():
+       if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+           decayed_params.append(param)
+       else:
+           no_decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
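
Two related changes recur in this file and the ones below. Because `Parameter` now subclasses `MetaTensor` (first hunk), `cell.weight.shape` and `cell.weight.dtype` are available without detouring through `default_input`; and dropping `.to_tensor()` keeps the assignment lazy. A sketch, assuming the `initializer` API of this MindSpore release:

    from mindspore.common import initializer as weight_init

    # Lazy: records init method, shape and dtype; no buffer is allocated
    # until init_parameters_data() / init_data() materializes it.
    lazy = weight_init.initializer(weight_init.XavierUniform(), [64, 3, 7, 7])

    # Eager (the old style): builds the backing data immediately.
    eager = weight_init.initializer(weight_init.XavierUniform(), [64, 3, 7, 7]).to_tensor()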

View File

@@ -91,12 +91,12 @@ if __name__ == '__main__':
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
-                                                               cell.weight.default_input.shape,
-                                                               cell.weight.default_input.dtype).to_tensor()
+                                                               cell.weight.shape,
+                                                               cell.weight.dtype)
        if isinstance(cell, nn.Dense):
            cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
-                                                               cell.weight.default_input.shape,
-                                                               cell.weight.default_input.dtype).to_tensor()
+                                                               cell.weight.shape,
+                                                               cell.weight.dtype)
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

View File

@@ -63,19 +63,19 @@ class Resnet(ImageClassificationNetwork):
            if isinstance(cell, nn.Conv2d):
                cell.weight.default_input = init.initializer(
                    KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
-                   cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                   cell.weight.shape, cell.weight.dtype)
            elif isinstance(cell, nn.BatchNorm2d):
-               cell.gamma.default_input = init.initializer('ones', cell.gamma.default_input.shape).to_tensor()
-               cell.beta.default_input = init.initializer('zeros', cell.beta.default_input.shape).to_tensor()
+               cell.gamma.default_input = init.initializer('ones', cell.gamma.shape)
+               cell.beta.default_input = init.initializer('zeros', cell.beta.shape)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        for cell in self.cells_and_names():
            if isinstance(cell, backbones.resnet.Bottleneck):
-               cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.default_input.shape).to_tensor()
+               cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.shape)
            elif isinstance(cell, backbones.resnet.BasicBlock):
-               cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.default_input.shape).to_tensor()
+               cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.shape)

View File

@@ -19,7 +19,6 @@ import math
 from functools import reduce
 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor
 from mindspore.common import initializer as init

 def _calculate_gain(nonlinearity, param=None):
@@ -191,23 +190,25 @@ def default_recurisive_init(custom_cell):
    for _, cell in custom_cell.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                        cell.weight.default_input.shape,
-                                                        cell.weight.default_input.dtype).to_tensor()
+                                                        cell.weight.shape,
+                                                        cell.weight.dtype)
            if cell.bias is not None:
-               fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+               fan_in, _ = _calculate_in_and_out(cell.weight)
                bound = 1 / math.sqrt(fan_in)
                np.random.seed(0)
-               cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                cell.bias.default_input.dtype)
+               cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                          cell.bias.shape,
+                                                          cell.bias.dtype)
        elif isinstance(cell, nn.Dense):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                        cell.weight.default_input.shape,
-                                                        cell.weight.default_input.dtype).to_tensor()
+                                                        cell.weight.shape,
+                                                        cell.weight.dtype)
            if cell.bias is not None:
-               fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+               fan_in, _ = _calculate_in_and_out(cell.weight)
                bound = 1 / math.sqrt(fan_in)
                np.random.seed(0)
-               cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                cell.bias.default_input.dtype)
+               cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                          cell.bias.shape,
+                                                          cell.bias.dtype)
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass
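
Switching the bias fill from a seeded numpy draw to `init.Uniform` makes it declarative and lazy, and `_calculate_in_and_out` can now read the fan from the parameter's shape instead of a materialized numpy array. A sketch, assuming `Uniform(scale)` samples from [-scale, scale] as in MindSpore's initializer module:

    import math
    from mindspore.common import initializer as init

    fan_in = 64 * 3 * 3                  # e.g. 64 input channels, 3x3 kernel
    bound = 1 / math.sqrt(fan_in)
    bias = init.initializer(init.Uniform(bound), [128])   # lazy; no data yet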

View File

@@ -19,7 +19,6 @@ import math
 from functools import reduce
 import numpy as np
 import mindspore.nn as nn
-from mindspore import Tensor
 from mindspore.common import initializer as init

 def _calculate_gain(nonlinearity, param=None):
@@ -191,23 +190,25 @@ def default_recurisive_init(custom_cell):
    for _, cell in custom_cell.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                        cell.weight.default_input.shape,
-                                                        cell.weight.default_input.dtype).to_tensor()
+                                                        cell.weight.shape,
+                                                        cell.weight.dtype)
            if cell.bias is not None:
-               fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+               fan_in, _ = _calculate_in_and_out(cell.weight)
                bound = 1 / math.sqrt(fan_in)
                np.random.seed(0)
-               cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                cell.bias.default_input.dtype)
+               cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                          cell.bias.shape,
+                                                          cell.bias.dtype)
        elif isinstance(cell, nn.Dense):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                        cell.weight.default_input.shape,
-                                                        cell.weight.default_input.dtype).to_tensor()
+                                                        cell.weight.shape,
+                                                        cell.weight.dtype)
            if cell.bias is not None:
-               fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+               fan_in, _ = _calculate_in_and_out(cell.weight)
                bound = 1 / math.sqrt(fan_in)
                np.random.seed(0)
-               cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                cell.bias.default_input.dtype)
+               cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                          cell.bias.shape,
+                                                          cell.bias.dtype)
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass

View File

@@ -102,16 +102,16 @@ class Vgg(nn.Cell):
            if isinstance(cell, nn.Conv2d):
                cell.weight.default_input = init.initializer(
                    KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
-                   cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                   cell.weight.shape, cell.weight.dtype)
                if cell.bias is not None:
                    cell.bias.default_input = init.initializer(
-                       'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()
+                       'zeros', cell.bias.shape, cell.bias.dtype)
            elif isinstance(cell, nn.Dense):
                cell.weight.default_input = init.initializer(
-                   init.Normal(0.01), cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                   init.Normal(0.01), cell.weight.shape, cell.weight.dtype)
                if cell.bias is not None:
                    cell.bias.default_input = init.initializer(
-                       'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()
+                       'zeros', cell.bias.shape, cell.bias.dtype)

cfg = {

View File

@@ -14,11 +14,11 @@
 # ============================================================================
 """Parameter init."""
 import math
+from functools import reduce
 import numpy as np
 from mindspore.common import initializer as init
 from mindspore.common.initializer import Initializer as MeInitializer
 import mindspore.nn as nn
-from mindspore import Tensor

 np.random.seed(5)
@@ -134,7 +134,7 @@ def _calculate_fan_in_and_fan_out(arr):
    num_output_fmaps = arr.shape[0]
    receptive_field_size = 1
    if dimensions > 2:
-       receptive_field_size = arr[0][0].size
+       receptive_field_size = reduce(lambda x, y: x * y, arr.shape[2:])
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
@@ -159,21 +159,23 @@ def default_recurisive_init(custom_cell):
    for _, cell in custom_cell.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                        cell.weight.default_input.shape,
-                                                        cell.weight.default_input.dtype).to_tensor()
+                                                        cell.weight.shape,
+                                                        cell.weight.dtype)
            if cell.bias is not None:
-               fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
+               fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
                bound = 1 / math.sqrt(fan_in)
-               cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                cell.bias.default_input.dtype)
+               cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                          cell.bias.shape,
+                                                          cell.bias.dtype)
        elif isinstance(cell, nn.Dense):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
-                                                        cell.weight.default_input.shape,
-                                                        cell.weight.default_input.dtype).to_tensor()
+                                                        cell.weight.shape,
+                                                        cell.weight.dtype)
            if cell.bias is not None:
-               fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
+               fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
                bound = 1 / math.sqrt(fan_in)
-               cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
-                                                cell.bias.default_input.dtype)
+               cell.bias.default_input = init.initializer(init.Uniform(bound),
+                                                          cell.bias.shape,
+                                                          cell.bias.dtype)
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass
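
`arr[0][0].size` needs real array data, while the `reduce` over `shape[2:]` needs only the shape, so it also works on a lazily-initialized parameter. On materialized data the two agree; a quick check of that equivalence:

    from functools import reduce
    import numpy as np

    w = np.zeros((64, 3, 5, 5))          # (out, in, kH, kW) conv weight
    old = w[0][0].size                   # 25, but requires materialized data
    new = reduce(lambda x, y: x * y, w.shape[2:])   # 25, from the shape alone
    assert old == new                    # the `if dimensions > 2` guard keeps
                                         # reduce() from seeing an empty sequence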

View File

@@ -58,7 +58,7 @@ def load_backbone(net, ckpt_path, args):
    darknet_backbone_prefix = 'network.backbone'
    find_param = []
    not_found_param = []
+   net.init_parameters_data()
    for name, cell in net.cells_and_names():
        if name.startswith(yolo_backbone_prefix):
            name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)