!21154 Add resnet lars and config.

Merge pull request !21154 from linqingke/resnet
This commit is contained in:
i-robot 2021-08-02 10:12:32 +00:00 committed by Gitee
commit 076aa59dab
16 changed files with 127 additions and 54 deletions

View File

@ -133,56 +133,62 @@ class ParameterProcess:
if isinstance(origin_params_copy[0], Parameter):
group_params = [{"params": parameters}]
else:
group_params = []
params_name = [param.name for param in parameters]
new_params_count = copy.deepcopy(params_name)
new_params_clone = {}
max_key_number = 0
for group_param in origin_params_copy:
if 'order_params' in group_param.keys():
new_group_param = copy.deepcopy(group_param)
new_group_param['order_params'] = parameters
group_params.append(new_group_param)
continue
params_value = []
for param in group_param['params']:
if param.name in params_name:
index = params_name.index(param.name)
params_value.append(parameters[index])
new_params_count.remove(param.name)
return group_params
group_params = []
params_name = [param.name for param in parameters]
new_params_count = copy.deepcopy(params_name)
new_params_clone = {}
max_key_number = 0
for group_param in origin_params_copy:
if 'order_params' in group_param.keys():
new_group_param = copy.deepcopy(group_param)
new_group_param['params'] = params_value
new_group_param['order_params'] = parameters
group_params.append(new_group_param)
if len(group_param.keys()) > max_key_number:
max_key_number = len(group_param.keys())
new_params_clone = copy.deepcopy(group_param)
if new_params_count:
params_value = []
for param in new_params_count:
index = params_name.index(param)
continue
params_value = []
for param in group_param['params']:
if param.name in params_name:
index = params_name.index(param.name)
params_value.append(parameters[index])
if new_params_clone:
new_params_clone['params'] = params_value
group_params.append(new_params_clone)
else:
group_params.append({"params": params_value})
new_params_count.remove(param.name)
new_group_param = copy.deepcopy(group_param)
new_group_param['params'] = params_value
group_params.append(new_group_param)
if len(group_param.keys()) > max_key_number:
max_key_number = len(group_param.keys())
new_params_clone = copy.deepcopy(group_param)
if new_params_count:
params_value = []
for param in new_params_count:
index = params_name.index(param)
params_value.append(parameters[index])
if new_params_clone:
new_params_clone['params'] = params_value
group_params.append(new_params_clone)
else:
group_params.append({"params": params_value})
return group_params
_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def _cumulative_grad(accumulation_step, cumulative_grad, grad):
"""Apply gradient accumulation to cumulative grad."""
return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
@_gradient_clear_op.register("Tensor")
def _clear_grad(cumulative_grad):
zero_grad = P.ZerosLike()(cumulative_grad)
return F.assign(cumulative_grad, zero_grad)
class GradientAccumulation(Cell):
"""
After accumulating the gradients of multiple steps, call to optimize its update.

View File

@ -243,6 +243,7 @@ class GradientFreeze:
return network, optimizer
def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
max_accumulation_step=1):
"""Provide freeze network cell."""

View File

@ -81,6 +81,7 @@ class CommonHeadLastFN(Cell):
x = self.multiplier * x
return x
class LessBN(Cell):
"""
Reduce the number of BN automatically to improve the network performance

View File

@ -33,6 +33,8 @@ lr_decay_mode: "cosine"
use_label_smooth: True
label_smooth_factor: 0.1
lr: 0.1
lars_epsilon: 0.0
lars_coefficient: 0.001
net_name: "resnet101"
dataset: "imagenet2012"
@ -49,6 +51,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
all_reduce_fusion_config:
- 2
- 60

View File

@ -33,6 +33,8 @@ lr_decay_mode: "poly"
lr_init: 0.01
lr_end: 0.00001
lr_max: 0.1
lars_epsilon: 0.0
lars_coefficient: 0.001
net_name: "resnet18"
dataset: "cifar10"
@ -49,6 +51,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
# Export options
device_id: 0

View File

@ -35,6 +35,8 @@ label_smooth_factor: 0.1
lr_init: 0
lr_max: 0.8
lr_end: 0.0
lars_epsilon: 0.0
lars_coefficient: 0.001
net_name: "resnet18"
dataset: "imagenet2012"
@ -51,6 +53,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
# Export options
device_id: 0

View File

@ -35,6 +35,8 @@ label_smooth_factor: 0.1
lr_init: 0
lr_max: 0.8
lr_end: 0.0
lars_epsilon: 0.0
lars_coefficient: 0.001
net_name: "resnet34"
dataset: "imagenet2012"
@ -51,6 +53,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
# Export options
device_id: 0

View File

@ -33,6 +33,8 @@ lr_decay_mode: "poly"
lr_init: 0.01
lr_end: 0.00001
lr_max: 0.1
lars_epsilon: 0.0
lars_coefficient: 0.001
net_name: "resnet50"
dataset: "cifar10"
@ -49,6 +51,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
all_reduce_fusion_config:
- 2
- 115

View File

@ -35,6 +35,8 @@ label_smooth_factor: 0.1
lr_init: 0
lr_max: 0.8
lr_end: 0.0
lars_epsilon: 0.0
lars_coefficient: 0.001
net_name: "resnet50"
dataset: "imagenet2012"
@ -51,6 +53,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O1"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
all_reduce_fusion_config:
- 85
- 160

View File

@ -33,6 +33,8 @@ label_smooth_factor: 0.1
lr_init: 0.05803
lr_decay: 4.04839
lr_end_epoch: 53
lars_epsilon: 0.0
lars_coefficient: 0.001
damping_init: 0.02714
damping_decay: 0.50036
frequency: 834
@ -52,6 +54,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
all_reduce_fusion_config:
- 85
- 160

View File

@ -33,6 +33,8 @@ label_smooth_factor: 0.1
lr_init: 0.05672
lr_decay: 4.9687
lr_end_epoch: 50
lars_epsilon: 0.0
lars_coefficient: 0.001
damping_init: 0.02345
damping_decay: 0.5467
frequency: 834
@ -52,6 +54,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
all_reduce_fusion_config:
- 85
- 160

View File

@ -35,6 +35,8 @@ label_smooth_factor: 0.1
lr_init: 0
lr_max: 0.8
lr_end: 0.0
lars_epsilon: 0.0
lars_coefficient: 0.001
net_name: "resnet50"
dataset: "imagenet2012"
@ -51,6 +53,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
all_reduce_fusion_config:
- 85
- 160

View File

@ -26,6 +26,8 @@ save_ckpt: False
mode_name: "GRAPH"
dtype: "fp16"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
# Export options
device_id: 0

View File

@ -36,6 +36,8 @@ label_smooth_factor: 0.1
lr_init: 0
lr_end: 0.0001
lr_max: 0.3
lars_epsilon: 0.0
lars_coefficient: 0.001
net_name: "se-resnet50"
dataset: "imagenet2012"
@ -52,6 +54,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
all_reduce_fusion_config:
- 1
- 100

View File

@ -23,7 +23,7 @@ from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
fan_in = in_channel * kernel_size * kernel_size
scale = 1.0
scale /= max(1., fan_in)
@ -108,7 +108,7 @@ def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
else:
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@ -121,7 +121,7 @@ def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
else:
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@ -134,7 +134,7 @@ def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
else:
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@ -207,7 +207,7 @@ class ResidualBlock(nn.Cell):
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
self.bn3 = _bn_last(out_channel)
self.bn3 = _bn(out_channel)
if self.se_block:
self.se_global_pool = P.ReduceMean(keep_dims=False)
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)

View File

@ -14,9 +14,10 @@
# ============================================================================
"""train resnet."""
import os
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim import Momentum, thor
from mindspore.nn.optim import Momentum, thor, LARS
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.train_thor import ConvertModelUtils
@ -37,6 +38,7 @@ from src.metric import DistAccuracy, ClassifyCorrectCell
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_rank_id, get_device_num
from src.resnet import conv_variance_scaling_initializer
set_seed(1)
@ -130,13 +132,26 @@ def init_weight(net):
else:
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
cell.weight.shape,
cell.weight.dtype))
if config.conv_init == "XavierUniform":
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
cell.weight.shape,
cell.weight.dtype))
elif config.conv_init == "TruncatedNormal":
weight = conv_variance_scaling_initializer(cell.in_channels,
cell.out_channels,
cell.kernel_size[0])
cell.weight.set_data(weight)
if isinstance(cell, nn.Dense):
cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.shape,
cell.weight.dtype))
if config.dense_init == "TruncatedNormal":
cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.shape,
cell.weight.dtype))
elif config.dense_init == "RandomNormal":
in_channel = cell.in_channels
out_channel = cell.out_channels
weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=cell.weight.dtype)
cell.weight.set_data(weight)
def init_lr(step_size):
"""init lr"""
@ -163,6 +178,21 @@ def init_loss_scale():
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
return loss
def init_group_params(net):
decayed_params = []
no_decayed_params = []
for param in net.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
return group_params
def run_eval(target, model, ckpt_save_dir, cb):
"""run_eval"""
if config.run_eval:
@ -205,18 +235,11 @@ def train_net():
init_weight(net=net)
lr = Tensor(init_lr(step_size=step_size))
# define opt
decayed_params = []
no_decayed_params = []
for param in net.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
group_params = init_group_params(net)
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
if config.optimizer == "LARS":
opt = LARS(opt, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)
loss = init_loss_scale()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
dist_eval_network = ClassifyCorrectCell(net) if config.run_distribute else None