From 0beb46325b73572352d687246200dd7e8447cbab Mon Sep 17 00:00:00 2001
From: linqingke
Date: Fri, 30 Jul 2021 17:53:58 +0800
Subject: [PATCH] add resnet 32k acc.

---
 mindspore/nn/acc/base.py                       | 68 ++++++++++---------
 mindspore/nn/acc/grad_freeze.py                |  1 +
 mindspore/nn/acc/less_batch_normalization.py   |  1 +
 .../resnet/resnet101_imagenet2012_config.yaml  |  4 ++
 .../cv/resnet/resnet18_cifar10_config.yaml     |  4 ++
 .../resnet/resnet18_imagenet2012_config.yaml   |  4 ++
 .../resnet/resnet34_imagenet2012_config.yaml   |  4 ++
 .../cv/resnet/resnet50_cifar10_config.yaml     |  4 ++
 .../resnet50_imagenet2012_Acc_config.yaml      |  4 ++
 ...net50_imagenet2012_Ascend_Thor_config.yaml  |  4 ++
 ...resnet50_imagenet2012_GPU_Thor_config.yaml  |  4 ++
 .../resnet/resnet50_imagenet2012_config.yaml   |  4 ++
 .../cv/resnet/resnet_benchmark_GPU.yaml        |  2 +
 .../se-resnet50_imagenet2012_config.yaml       |  4 ++
 model_zoo/official/cv/resnet/src/resnet.py     | 10 +--
 model_zoo/official/cv/resnet/train.py          | 59 +++++++++++-----
 16 files changed, 127 insertions(+), 54 deletions(-)

diff --git a/mindspore/nn/acc/base.py b/mindspore/nn/acc/base.py
index a0be25582d6..b8c5587d7c3 100644
--- a/mindspore/nn/acc/base.py
+++ b/mindspore/nn/acc/base.py
@@ -133,56 +133,62 @@ class ParameterProcess:
         if isinstance(origin_params_copy[0], Parameter):
             group_params = [{"params": parameters}]
-        else:
-            group_params = []
-            params_name = [param.name for param in parameters]
-            new_params_count = copy.deepcopy(params_name)
-            new_params_clone = {}
-            max_key_number = 0
-            for group_param in origin_params_copy:
-                if 'order_params' in group_param.keys():
-                    new_group_param = copy.deepcopy(group_param)
-                    new_group_param['order_params'] = parameters
-                    group_params.append(new_group_param)
-                    continue
-                params_value = []
-                for param in group_param['params']:
-                    if param.name in params_name:
-                        index = params_name.index(param.name)
-                        params_value.append(parameters[index])
-                        new_params_count.remove(param.name)
+            return group_params
+
+        group_params = []
+        params_name = [param.name for param in parameters]
+        new_params_count = copy.deepcopy(params_name)
+        new_params_clone = {}
+        max_key_number = 0
+        for group_param in origin_params_copy:
+            if 'order_params' in group_param.keys():
                 new_group_param = copy.deepcopy(group_param)
-                new_group_param['params'] = params_value
+                new_group_param['order_params'] = parameters
                 group_params.append(new_group_param)
-                if len(group_param.keys()) > max_key_number:
-                    max_key_number = len(group_param.keys())
-                    new_params_clone = copy.deepcopy(group_param)
-            if new_params_count:
-                params_value = []
-                for param in new_params_count:
-                    index = params_name.index(param)
+                continue
+            params_value = []
+            for param in group_param['params']:
+                if param.name in params_name:
+                    index = params_name.index(param.name)
                     params_value.append(parameters[index])
-                if new_params_clone:
-                    new_params_clone['params'] = params_value
-                    group_params.append(new_params_clone)
-                else:
-                    group_params.append({"params": params_value})
+                    new_params_count.remove(param.name)
+            new_group_param = copy.deepcopy(group_param)
+            new_group_param['params'] = params_value
+            group_params.append(new_group_param)
+            if len(group_param.keys()) > max_key_number:
+                max_key_number = len(group_param.keys())
+                new_params_clone = copy.deepcopy(group_param)
+        if new_params_count:
+            params_value = []
+            for param in new_params_count:
+                index = params_name.index(param)
+                params_value.append(parameters[index])
+            if new_params_clone:
+                new_params_clone['params'] = params_value
+                group_params.append(new_params_clone)
+            else:
+                group_params.append({"params": params_value})
         return group_params
 
+
 _gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
 
+
 @_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
 def _cumulative_grad(accumulation_step, cumulative_grad, grad):
     """Apply gradient accumulation to cumulative grad."""
     return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
 
+
 _gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
 
+
 @_gradient_clear_op.register("Tensor")
 def _clear_grad(cumulative_grad):
     zero_grad = P.ZerosLike()(cumulative_grad)
     return F.assign(cumulative_grad, zero_grad)
 
+
 class GradientAccumulation(Cell):
     """
     After accumulating the gradients of multiple steps, call to optimize its update.
diff --git a/mindspore/nn/acc/grad_freeze.py b/mindspore/nn/acc/grad_freeze.py
index dd8835953ec..8e84d4f12ab 100644
--- a/mindspore/nn/acc/grad_freeze.py
+++ b/mindspore/nn/acc/grad_freeze.py
@@ -243,6 +243,7 @@ class GradientFreeze:
 
         return network, optimizer
 
+
 def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
                 max_accumulation_step=1):
     """Provide freeze network cell."""
diff --git a/mindspore/nn/acc/less_batch_normalization.py b/mindspore/nn/acc/less_batch_normalization.py
index c2c6683afef..d1d35b4a94d 100644
--- a/mindspore/nn/acc/less_batch_normalization.py
+++ b/mindspore/nn/acc/less_batch_normalization.py
@@ -81,6 +81,7 @@ class CommonHeadLastFN(Cell):
         x = self.multiplier * x
         return x
 
+
 class LessBN(Cell):
     """
     Reduce the number of BN automatically to improve the network performance
diff --git a/model_zoo/official/cv/resnet/resnet101_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/resnet101_imagenet2012_config.yaml
index 0ce8e0161d0..3516107b267 100644
--- a/model_zoo/official/cv/resnet/resnet101_imagenet2012_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet101_imagenet2012_config.yaml
@@ -33,6 +33,8 @@ lr_decay_mode: "cosine"
 use_label_smooth: True
 label_smooth_factor: 0.1
 lr: 0.1
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet101"
 dataset: "imagenet2012"
@@ -49,6 +51,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 2
     - 60
diff --git a/model_zoo/official/cv/resnet/resnet18_cifar10_config.yaml b/model_zoo/official/cv/resnet/resnet18_cifar10_config.yaml
index e164bffd506..d1113625aaa 100644
--- a/model_zoo/official/cv/resnet/resnet18_cifar10_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet18_cifar10_config.yaml
@@ -33,6 +33,8 @@ lr_decay_mode: "poly"
 lr_init: 0.01
 lr_end: 0.00001
 lr_max: 0.1
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet18"
 dataset: "cifar10"
@@ -49,6 +51,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 
 # Export options
 device_id: 0
diff --git a/model_zoo/official/cv/resnet/resnet18_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/resnet18_imagenet2012_config.yaml
index 92c66f238a2..d52d7fdf531 100644
--- a/model_zoo/official/cv/resnet/resnet18_imagenet2012_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet18_imagenet2012_config.yaml
@@ -35,6 +35,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_max: 0.8
 lr_end: 0.0
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet18"
 dataset: "imagenet2012"
@@ -51,6 +53,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 
 # Export options
 device_id: 0
diff --git a/model_zoo/official/cv/resnet/resnet34_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/resnet34_imagenet2012_config.yaml
index 5b4b0493dfa..8e7d79ff503 100644
--- a/model_zoo/official/cv/resnet/resnet34_imagenet2012_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet34_imagenet2012_config.yaml
@@ -35,6 +35,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_max: 0.8
 lr_end: 0.0
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet34"
 dataset: "imagenet2012"
@@ -51,6 +53,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 
 # Export options
 device_id: 0
diff --git a/model_zoo/official/cv/resnet/resnet50_cifar10_config.yaml b/model_zoo/official/cv/resnet/resnet50_cifar10_config.yaml
index 51021bb5a39..fe0995baa9b 100644
--- a/model_zoo/official/cv/resnet/resnet50_cifar10_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet50_cifar10_config.yaml
@@ -33,6 +33,8 @@ lr_decay_mode: "poly"
 lr_init: 0.01
 lr_end: 0.00001
 lr_max: 0.1
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet50"
 dataset: "cifar10"
@@ -49,6 +51,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 2
     - 115
diff --git a/model_zoo/official/cv/resnet/resnet50_imagenet2012_Acc_config.yaml b/model_zoo/official/cv/resnet/resnet50_imagenet2012_Acc_config.yaml
index 80456c685db..5563840782e 100644
--- a/model_zoo/official/cv/resnet/resnet50_imagenet2012_Acc_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet50_imagenet2012_Acc_config.yaml
@@ -35,6 +35,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_max: 0.8
 lr_end: 0.0
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet50"
 dataset: "imagenet2012"
@@ -51,6 +53,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O1"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 85
     - 160
diff --git a/model_zoo/official/cv/resnet/resnet50_imagenet2012_Ascend_Thor_config.yaml b/model_zoo/official/cv/resnet/resnet50_imagenet2012_Ascend_Thor_config.yaml
index 2b730eb81ff..0b4f48f33cb 100644
--- a/model_zoo/official/cv/resnet/resnet50_imagenet2012_Ascend_Thor_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet50_imagenet2012_Ascend_Thor_config.yaml
@@ -33,6 +33,8 @@ label_smooth_factor: 0.1
 lr_init: 0.05803
 lr_decay: 4.04839
 lr_end_epoch: 53
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 damping_init: 0.02714
 damping_decay: 0.50036
 frequency: 834
@@ -52,6 +54,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 85
     - 160
diff --git a/model_zoo/official/cv/resnet/resnet50_imagenet2012_GPU_Thor_config.yaml b/model_zoo/official/cv/resnet/resnet50_imagenet2012_GPU_Thor_config.yaml
index dd4b492f7e3..94661d85aa9 100644
--- a/model_zoo/official/cv/resnet/resnet50_imagenet2012_GPU_Thor_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet50_imagenet2012_GPU_Thor_config.yaml
@@ -33,6 +33,8 @@ label_smooth_factor: 0.1
 lr_init: 0.05672
 lr_decay: 4.9687
 lr_end_epoch: 50
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 damping_init: 0.02345
 damping_decay: 0.5467
 frequency: 834
@@ -52,6 +54,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 85
     - 160
diff --git a/model_zoo/official/cv/resnet/resnet50_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/resnet50_imagenet2012_config.yaml
index a9873711004..89938e3b496 100644
--- a/model_zoo/official/cv/resnet/resnet50_imagenet2012_config.yaml
+++ b/model_zoo/official/cv/resnet/resnet50_imagenet2012_config.yaml
@@ -35,6 +35,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_max: 0.8
 lr_end: 0.0
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "resnet50"
 dataset: "imagenet2012"
@@ -51,6 +53,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 85
     - 160
diff --git a/model_zoo/official/cv/resnet/resnet_benchmark_GPU.yaml b/model_zoo/official/cv/resnet/resnet_benchmark_GPU.yaml
index fb4d21e54c9..9491769f4b8 100644
--- a/model_zoo/official/cv/resnet/resnet_benchmark_GPU.yaml
+++ b/model_zoo/official/cv/resnet/resnet_benchmark_GPU.yaml
@@ -26,6 +26,8 @@ save_ckpt: False
 mode_name: "GRAPH"
 dtype: "fp16"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 
 # Export options
 device_id: 0
diff --git a/model_zoo/official/cv/resnet/se-resnet50_imagenet2012_config.yaml b/model_zoo/official/cv/resnet/se-resnet50_imagenet2012_config.yaml
index 7d98865ddc9..9b59225cb65 100644
--- a/model_zoo/official/cv/resnet/se-resnet50_imagenet2012_config.yaml
+++ b/model_zoo/official/cv/resnet/se-resnet50_imagenet2012_config.yaml
@@ -36,6 +36,8 @@ label_smooth_factor: 0.1
 lr_init: 0
 lr_end: 0.0001
 lr_max: 0.3
+lars_epsilon: 0.0
+lars_coefficient: 0.001
 
 net_name: "se-resnet50"
 dataset: "imagenet2012"
@@ -52,6 +54,8 @@ enable_cache: False
 cache_session_id: ""
 mode_name: "GRAPH"
 acc_mode: "O0"
+conv_init: "XavierUniform"
+dense_init: "TruncatedNormal"
 all_reduce_fusion_config:
     - 1
     - 100
diff --git a/model_zoo/official/cv/resnet/src/resnet.py b/model_zoo/official/cv/resnet/src/resnet.py
index 0405e38cafa..54174d4ad7d 100755
--- a/model_zoo/official/cv/resnet/src/resnet.py
+++ b/model_zoo/official/cv/resnet/src/resnet.py
@@ -23,7 +23,7 @@ from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
 
 
-def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
+def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
     fan_in = in_channel * kernel_size * kernel_size
     scale = 1.0
     scale /= max(1., fan_in)
@@ -108,7 +108,7 @@ def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'
 
 def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
     if use_se:
-        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
+        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
     else:
         weight_shape = (out_channel, in_channel, 3, 3)
         weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@@ -121,7 +121,7 @@ def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
     if use_se:
-        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
+        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
     else:
         weight_shape = (out_channel, in_channel, 1, 1)
         weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@@ -134,7 +134,7 @@ def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
     if use_se:
-        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
+        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
     else:
         weight_shape = (out_channel, in_channel, 7, 7)
         weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
@@ -207,7 +207,7 @@ class ResidualBlock(nn.Cell):
         self.bn2 = _bn(channel)
 
         self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
-        self.bn3 = _bn_last(out_channel)
+        self.bn3 = _bn(out_channel)
         if self.se_block:
             self.se_global_pool = P.ReduceMean(keep_dims=False)
             self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py
index e1440dd65e0..7048543b7ac 100755
--- a/model_zoo/official/cv/resnet/train.py
+++ b/model_zoo/official/cv/resnet/train.py
@@ -14,9 +14,10 @@
 # ============================================================================
 """train resnet."""
 import os
+import numpy as np
 from mindspore import context
 from mindspore import Tensor
-from mindspore.nn.optim import Momentum, thor
+from mindspore.nn.optim import Momentum, thor, LARS
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
 from mindspore.train.train_thor import ConvertModelUtils
@@ -37,6 +38,7 @@ from src.metric import DistAccuracy, ClassifyCorrectCell
 from src.model_utils.config import config
 from src.model_utils.moxing_adapter import moxing_wrapper
 from src.model_utils.device_adapter import get_rank_id, get_device_num
+from src.resnet import conv_variance_scaling_initializer
 
 set_seed(1)
 
@@ -130,13 +132,26 @@ def init_weight(net):
     else:
         for _, cell in net.cells_and_names():
             if isinstance(cell, nn.Conv2d):
-                cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
-                                                             cell.weight.shape,
-                                                             cell.weight.dtype))
+                if config.conv_init == "XavierUniform":
+                    cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
+                                                                 cell.weight.shape,
+                                                                 cell.weight.dtype))
+                elif config.conv_init == "TruncatedNormal":
+                    weight = conv_variance_scaling_initializer(cell.in_channels,
+                                                               cell.out_channels,
+                                                               cell.kernel_size[0])
+                    cell.weight.set_data(weight)
             if isinstance(cell, nn.Dense):
-                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
-                                                             cell.weight.shape,
-                                                             cell.weight.dtype))
+                if config.dense_init == "TruncatedNormal":
+                    cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
+                                                                 cell.weight.shape,
+                                                                 cell.weight.dtype))
+                elif config.dense_init == "RandomNormal":
+                    in_channel = cell.in_channels
+                    out_channel = cell.out_channels
+                    weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
+                    weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=cell.weight.dtype)
+                    cell.weight.set_data(weight)
 
 def init_lr(step_size):
     """init lr"""
@@ -163,6 +178,21 @@ def init_loss_scale():
         loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     return loss
 
+
+def init_group_params(net):
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
+
+    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
+                    {'params': no_decayed_params},
+                    {'order_params': net.trainable_params()}]
+    return group_params
+
 def run_eval(target, model, ckpt_save_dir, cb):
     """run_eval"""
     if config.run_eval:
@@ -205,18 +235,11 @@ def train_net():
     init_weight(net=net)
     lr = Tensor(init_lr(step_size=step_size))
     # define opt
-    decayed_params = []
-    no_decayed_params = []
-    for param in net.trainable_params():
-        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
-            decayed_params.append(param)
-        else:
-            no_decayed_params.append(param)
-
-    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
-                    {'params': no_decayed_params},
-                    {'order_params': net.trainable_params()}]
+    group_params = init_group_params(net)
     opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
+    if config.optimizer == "LARS":
+        opt = LARS(opt, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
+                   lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)
     loss = init_loss_scale()
     loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    dist_eval_network = ClassifyCorrectCell(net) if config.run_distribute else None
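
Usage sketch (illustrative, not part of the patch): train.py above wraps the Momentum optimizer in mindspore.nn.LARS when the yaml's optimizer field is set to "LARS", using the newly added lars_epsilon and lars_coefficient keys. The snippet below is a minimal, self-contained sketch of that composition under those assumptions; the small Dense network is only a stand-in for the ResNet built in src/resnet.py, and the hard-coded hyperparameters mirror the config defaults.

    import mindspore.nn as nn

    # Stand-in network; the patch applies this to ResNet-50.
    net = nn.Dense(16, 10)

    # Same grouping as init_group_params(): apply weight decay to everything
    # except BatchNorm beta/gamma and bias parameters.
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': 1e-4},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]

    # Inner optimizer, as in train.py.
    opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, loss_scale=1024.0)

    # LARS wraps the inner optimizer; BN and bias parameters are excluded from
    # the layer-wise rescaling via lars_filter, matching the patch.
    opt = nn.LARS(opt, epsilon=0.0, coefficient=0.001,
                  lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)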