forked from mindspore-Ecosystem/mindspore
!21154 Add resnet lars and config.
Merge pull request !21154 from linqingke/resnet
commit 076aa59dab
@@ -133,56 +133,62 @@ class ParameterProcess:
         if isinstance(origin_params_copy[0], Parameter):
             group_params = [{"params": parameters}]
-        else:
-            group_params = []
-            params_name = [param.name for param in parameters]
-            new_params_count = copy.deepcopy(params_name)
-            for group_param in origin_params_copy:
-                if 'order_params' in group_param.keys():
-                    new_group_param = copy.deepcopy(group_param)
-                    new_group_param['order_params'] = parameters
-                    group_params.append(new_group_param)
-                    continue
-                params_value = []
-                for param in group_param['params']:
-                    if param.name in params_name:
-                        index = params_name.index(param.name)
-                        params_value.append(parameters[index])
-                        new_params_count.remove(param.name)
-                new_group_param = copy.deepcopy(group_param)
-                new_group_param['params'] = params_value
-                group_params.append(new_group_param)
-            if new_params_count:
-                params_value = []
-                for param in new_params_count:
-                    index = params_name.index(param)
-                    params_value.append(parameters[index])
-                group_params.append({"params": params_value})
-        return group_params
+            return group_params
+
+        group_params = []
+        params_name = [param.name for param in parameters]
+        new_params_count = copy.deepcopy(params_name)
+        new_params_clone = {}
+        max_key_number = 0
+        for group_param in origin_params_copy:
+            if 'order_params' in group_param.keys():
+                new_group_param = copy.deepcopy(group_param)
+                new_group_param['order_params'] = parameters
+                group_params.append(new_group_param)
+                continue
+            params_value = []
+            for param in group_param['params']:
+                if param.name in params_name:
+                    index = params_name.index(param.name)
+                    params_value.append(parameters[index])
+                    new_params_count.remove(param.name)
+            new_group_param = copy.deepcopy(group_param)
+            new_group_param['params'] = params_value
+            group_params.append(new_group_param)
+            if len(group_param.keys()) > max_key_number:
+                max_key_number = len(group_param.keys())
+                new_params_clone = copy.deepcopy(group_param)
+        if new_params_count:
+            params_value = []
+            for param in new_params_count:
+                index = params_name.index(param)
+                params_value.append(parameters[index])
+            if new_params_clone:
+                new_params_clone['params'] = params_value
+                group_params.append(new_params_clone)
+            else:
+                group_params.append({"params": params_value})
+        return group_params
 
 
 _gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
 
 
 @_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
 def _cumulative_grad(accumulation_step, cumulative_grad, grad):
     """Apply gradient accumulation to cumulative grad."""
     return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
 
 
 _gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
 
 
 @_gradient_clear_op.register("Tensor")
 def _clear_grad(cumulative_grad):
     zero_grad = P.ZerosLike()(cumulative_grad)
     return F.assign(cumulative_grad, zero_grad)
 
 
 class GradientAccumulation(Cell):
     """
     After accumulating the gradients of multiple steps, call to optimize its update.
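The accumulation op above divides each incoming gradient by accumulation_step before adding it into the buffer, so after a full cycle the buffer holds the mean gradient of the micro-batches. A minimal NumPy sketch of that arithmetic (plain-Python stand-in, no MindSpore; the gradient values are illustrative):

import numpy as np

accumulation_step = 4
cumulative_grad = np.zeros(3)                    # freshly cleared buffer, as _clear_grad leaves it
micro_batch_grads = [np.ones(3) * k for k in range(1, 5)]

for grad in micro_batch_grads:
    cumulative_grad += grad / accumulation_step  # the same update _cumulative_grad applies

# after one full cycle the buffer equals the mean micro-batch gradient
assert np.allclose(cumulative_grad, np.mean(micro_batch_grads, axis=0))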
@@ -243,6 +243,7 @@ class GradientFreeze:
         return network, optimizer
 
 
-def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, max_accumulation_step=1):
+def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
+                max_accumulation_step=1):
     """Provide freeze network cell."""
@@ -81,6 +81,7 @@ class CommonHeadLastFN(Cell):
         x = self.multiplier * x
         return x
 
 
 class LessBN(Cell):
     """
     Reduce the number of BN automatically to improve the network performance
@@ -33,6 +33,8 @@ lr_decay_mode: "cosine"
 use_label_smooth: True
 label_smooth_factor: 0.1
 lr: 0.1
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet101"
 dataset: "imagenet2012"
@@ -49,6 +51,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 2
     - 60
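The two new lars_* keys parameterize the LARS wrapper that train.py builds below: epsilon guards the denominator of the layer-wise trust ratio and coefficient scales it. A rough NumPy sketch of that ratio in the usual LARS formulation (You et al., 2017); the exact MindSpore kernel may arrange weight decay differently:

import numpy as np

def lars_trust_ratio(w, g, coefficient=0.001, epsilon=0.0, weight_decay=0.0):
    # layer-wise scaling: coefficient * ||w|| / (||g|| + weight_decay * ||w|| + epsilon)
    w_norm = np.linalg.norm(w)
    g_norm = np.linalg.norm(g)
    return coefficient * w_norm / (g_norm + weight_decay * w_norm + epsilon)

w = np.ones((64, 64)) * 0.1
g = np.ones((64, 64)) * 0.001
print(lars_trust_ratio(w, g))  # multiplies the base lr for this layer only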
@@ -33,6 +33,8 @@ lr_decay_mode: "poly"
 lr_init: 0.01
 lr_end: 0.00001
 lr_max: 0.1
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet18"
 dataset: "cifar10"
@@ -49,6 +51,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 
 # Export options
 device_id: 0
@@ -35,6 +35,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_max: 0.8
 lr_end: 0.0
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet18"
 dataset: "imagenet2012"
@@ -51,6 +53,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 
 # Export options
 device_id: 0
@@ -35,6 +35,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_max: 0.8
 lr_end: 0.0
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet34"
 dataset: "imagenet2012"
@@ -51,6 +53,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 
 # Export options
 device_id: 0
@@ -33,6 +33,8 @@ lr_decay_mode: "poly"
 lr_init: 0.01
 lr_end: 0.00001
 lr_max: 0.1
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet50"
 dataset: "cifar10"
@@ -49,6 +51,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 2
     - 115
@@ -35,6 +35,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_max: 0.8
 lr_end: 0.0
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet50"
 dataset: "imagenet2012"
@@ -51,6 +53,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O1"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 85
     - 160
@@ -33,6 +33,8 @@ label_smooth_factor: 0.1
 lr_init: 0.05803
 lr_decay: 4.04839
 lr_end_epoch: 53
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 damping_init: 0.02714
 damping_decay: 0.50036
 frequency: 834
@@ -52,6 +54,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 85
     - 160
@@ -33,6 +33,8 @@ label_smooth_factor: 0.1
 lr_init: 0.05672
 lr_decay: 4.9687
 lr_end_epoch: 50
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 damping_init: 0.02345
 damping_decay: 0.5467
 frequency: 834
@@ -52,6 +54,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 85
     - 160
@@ -35,6 +35,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_max: 0.8
 lr_end: 0.0
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet50"
 dataset: "imagenet2012"
@@ -51,6 +53,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 85
     - 160
@@ -26,6 +26,8 @@ save_ckpt: False
 mode_name: "GRAPH"
 dtype: "fp16"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 
 # Export options
 device_id: 0
@@ -36,6 +36,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_end: 0.0001
 lr_max: 0.3
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "se-resnet50"
 dataset: "imagenet2012"
@@ -52,6 +54,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 1
     - 100
@@ -23,7 +23,7 @@ from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
 
 
-def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
+def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
     fan_in = in_channel * kernel_size * kernel_size
     scale = 1.0
     scale /= max(1., fan_in)
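The leading underscore is dropped so train.py can import the initializer (see the new src.resnet import below). Its scale is pure fan-in arithmetic, as the hunk shows; a standalone sketch, assuming the truncated-normal variant this model zoo uses (the correction constant is quoted from that convention and may differ elsewhere):

def variance_scaling_stddev(in_channel, kernel_size, scale=1.0):
    fan_in = in_channel * kernel_size * kernel_size   # same fan_in as in the diff above
    scale /= max(1., fan_in)
    # divide by the truncated-normal correction so the realized stddev matches scale**0.5
    return (scale ** 0.5) / .87962566103423978

print(variance_scaling_stddev(3, 7))   # stddev for a 7x7 conv over RGB input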
@@ -108,7 +108,7 @@ def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'
 
 def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
     if use_se:
-        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
+        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
     else:
         weight_shape = (out_channel, in_channel, 3, 3)
         weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@@ -121,7 +121,7 @@ def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
 
 def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
     if use_se:
-        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
+        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
     else:
         weight_shape = (out_channel, in_channel, 1, 1)
         weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@@ -134,7 +134,7 @@ def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
 
 def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
     if use_se:
-        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
+        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
     else:
         weight_shape = (out_channel, in_channel, 7, 7)
         weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@@ -207,7 +207,7 @@ class ResidualBlock(nn.Cell):
         self.bn2 = _bn(channel)
 
         self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
-        self.bn3 = _bn_last(out_channel)
+        self.bn3 = _bn(out_channel)
         if self.se_block:
             self.se_global_pool = P.ReduceMean(keep_dims=False)
             self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
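Note what this swap changes: _bn_last zero-initializes the BatchNorm gamma (the zero-init-residual trick), while _bn keeps the standard ones init, so the last BN in each bottleneck now lets the residual branch contribute from the first step. A toy sketch of the difference, modeling the last BN as a gamma scale on the branch (plain NumPy, illustrative):

import numpy as np

def residual_block_output(x, branch, gamma):
    return x + gamma * branch(x)      # last BN's gamma scales the branch before the skip add

x = np.random.randn(8)
branch = np.tanh

print(np.allclose(residual_block_output(x, branch, 0.0), x))  # _bn_last-style: block starts as identity
print(np.allclose(residual_block_output(x, branch, 1.0), x))  # _bn-style: branch active immediately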
@@ -14,9 +14,10 @@
 # ============================================================================
 """train resnet."""
 import os
+import numpy as np
 from mindspore import context
 from mindspore import Tensor
-from mindspore.nn.optim import Momentum, thor
+from mindspore.nn.optim import Momentum, thor, LARS
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
 from mindspore.train.train_thor import ConvertModelUtils
@@ -37,6 +38,7 @@ from src.metric import DistAccuracy, ClassifyCorrectCell
 from src.model_utils.config import config
 from src.model_utils.moxing_adapter import moxing_wrapper
 from src.model_utils.device_adapter import get_rank_id, get_device_num
+from src.resnet import conv_variance_scaling_initializer
 
 set_seed(1)
 
@@ -130,13 +132,26 @@ def init_weight(net):
     else:
         for _, cell in net.cells_and_names():
             if isinstance(cell, nn.Conv2d):
-                cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
-                                                             cell.weight.shape,
-                                                             cell.weight.dtype))
+                if config.conv_init == "XavierUniform":
+                    cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
+                                                                 cell.weight.shape,
+                                                                 cell.weight.dtype))
+                elif config.conv_init == "TruncatedNormal":
+                    weight = conv_variance_scaling_initializer(cell.in_channels,
+                                                               cell.out_channels,
+                                                               cell.kernel_size[0])
+                    cell.weight.set_data(weight)
             if isinstance(cell, nn.Dense):
-                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
-                                                             cell.weight.shape,
-                                                             cell.weight.dtype))
+                if config.dense_init == "TruncatedNormal":
+                    cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
+                                                                 cell.weight.shape,
+                                                                 cell.weight.dtype))
+                elif config.dense_init == "RandomNormal":
+                    in_channel = cell.in_channels
+                    out_channel = cell.out_channels
+                    weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
+                    weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=cell.weight.dtype)
+                    cell.weight.set_data(weight)
 
 def init_lr(step_size):
     """init lr"""
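The new RandomNormal branch draws a flat vector and reshapes it to (out_channels, in_channels), the layout nn.Dense weights use. A quick NumPy check of that shape logic, with illustrative head sizes:

import numpy as np

in_channel, out_channel = 2048, 1001   # e.g. a ResNet-50 ImageNet classifier head
weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
weight = np.reshape(weight, (out_channel, in_channel))

assert weight.shape == (1001, 2048)    # rows index output units, matching the diff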
@@ -163,6 +178,21 @@ def init_loss_scale():
     loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     return loss
 
+
+def init_group_params(net):
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
+
+    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
+                    {'params': no_decayed_params},
+                    {'order_params': net.trainable_params()}]
+    return group_params
+
 def run_eval(target, model, ckpt_save_dir, cb):
     """run_eval"""
     if config.run_eval:
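The factored-out init_group_params keeps weight decay on conv/dense weights only: any parameter whose name contains beta, gamma, or bias (BatchNorm affine terms and biases) lands in the no-decay group, and order_params preserves the original parameter ordering. A minimal sketch of the same filtering with stand-in objects (SimpleNamespace replaces mindspore.Parameter here):

from types import SimpleNamespace

params = [SimpleNamespace(name=n) for n in
          ("conv1.weight", "bn1.gamma", "bn1.beta", "fc.weight", "fc.bias")]

decayed = [p for p in params
           if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name]
no_decayed = [p for p in params if p not in decayed]

group_params = [{'params': decayed, 'weight_decay': 1e-4},
                {'params': no_decayed},
                {'order_params': params}]
print([p.name for p in decayed])   # ['conv1.weight', 'fc.weight']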
@@ -205,18 +235,11 @@ def train_net():
     init_weight(net=net)
     lr = Tensor(init_lr(step_size=step_size))
     # define opt
-    decayed_params = []
-    no_decayed_params = []
-    for param in net.trainable_params():
-        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
-            decayed_params.append(param)
-        else:
-            no_decayed_params.append(param)
-
-    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
-                    {'params': no_decayed_params},
-                    {'order_params': net.trainable_params()}]
+    group_params = init_group_params(net)
     opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
+    if config.optimizer == "LARS":
+        opt = LARS(opt, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
+                   lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)
     loss = init_loss_scale()
     loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
     dist_eval_network = ClassifyCorrectCell(net) if config.run_distribute else None
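LARS stays opt-in here: Momentum is always constructed, and only when config.optimizer is "LARS" does it get wrapped, with lars_filter reusing the same name test as the weight-decay grouping so beta/gamma/bias parameters keep the plain momentum update. A small sketch of what the filter admits (stand-in objects again, not MindSpore Parameters):

from types import SimpleNamespace

def lars_filter(x):
    return 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name

for name in ("layer1.conv.weight", "layer1.bn.gamma", "head.fc.bias"):
    scaled = lars_filter(SimpleNamespace(name=name))
    print(name, "->", "LARS-scaled" if scaled else "plain momentum")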