forked from mindspore-Ecosystem/mindspore
!21154 Add resnet lars and config.
Merge pull request !21154 from linqingke/resnet
This commit is contained in:
commit
076aa59dab
|
@ -133,7 +133,8 @@ class ParameterProcess:
|
||||||
|
|
||||||
if isinstance(origin_params_copy[0], Parameter):
|
if isinstance(origin_params_copy[0], Parameter):
|
||||||
group_params = [{"params": parameters}]
|
group_params = [{"params": parameters}]
|
||||||
else:
|
return group_params
|
||||||
|
|
||||||
group_params = []
|
group_params = []
|
||||||
params_name = [param.name for param in parameters]
|
params_name = [param.name for param in parameters]
|
||||||
new_params_count = copy.deepcopy(params_name)
|
new_params_count = copy.deepcopy(params_name)
|
||||||
|
@ -169,20 +170,25 @@ class ParameterProcess:
|
||||||
group_params.append({"params": params_value})
|
group_params.append({"params": params_value})
|
||||||
return group_params
|
return group_params
|
||||||
|
|
||||||
|
|
||||||
_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
|
_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
|
||||||
|
|
||||||
|
|
||||||
@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
|
@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
|
||||||
def _cumulative_grad(accumulation_step, cumulative_grad, grad):
|
def _cumulative_grad(accumulation_step, cumulative_grad, grad):
|
||||||
"""Apply gradient accumulation to cumulative grad."""
|
"""Apply gradient accumulation to cumulative grad."""
|
||||||
return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
|
return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
|
||||||
|
|
||||||
|
|
||||||
_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
|
_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
|
||||||
|
|
||||||
|
|
||||||
@_gradient_clear_op.register("Tensor")
|
@_gradient_clear_op.register("Tensor")
|
||||||
def _clear_grad(cumulative_grad):
|
def _clear_grad(cumulative_grad):
|
||||||
zero_grad = P.ZerosLike()(cumulative_grad)
|
zero_grad = P.ZerosLike()(cumulative_grad)
|
||||||
return F.assign(cumulative_grad, zero_grad)
|
return F.assign(cumulative_grad, zero_grad)
|
||||||
|
|
||||||
|
|
||||||
class GradientAccumulation(Cell):
|
class GradientAccumulation(Cell):
|
||||||
"""
|
"""
|
||||||
After accumulating the gradients of multiple steps, call to optimize its update.
|
After accumulating the gradients of multiple steps, call to optimize its update.
|
||||||
|
|
|
@ -243,6 +243,7 @@ class GradientFreeze:
|
||||||
|
|
||||||
return network, optimizer
|
return network, optimizer
|
||||||
|
|
||||||
|
|
||||||
def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
|
def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
|
||||||
max_accumulation_step=1):
|
max_accumulation_step=1):
|
||||||
"""Provide freeze network cell."""
|
"""Provide freeze network cell."""
|
||||||
|
|
|
@ -81,6 +81,7 @@ class CommonHeadLastFN(Cell):
|
||||||
x = self.multiplier * x
|
x = self.multiplier * x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LessBN(Cell):
|
class LessBN(Cell):
|
||||||
"""
|
"""
|
||||||
Reduce the number of BN automatically to improve the network performance
|
Reduce the number of BN automatically to improve the network performance
|
||||||
|
|
|
@ -33,6 +33,8 @@ lr_decay_mode: "cosine"
|
||||||
use_label_smooth: True
|
use_label_smooth: True
|
||||||
label_smooth_factor: 0.1
|
label_smooth_factor: 0.1
|
||||||
lr: 0.1
|
lr: 0.1
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
|
|
||||||
net_name: "resnet101"
|
net_name: "resnet101"
|
||||||
dataset: "imagenet2012"
|
dataset: "imagenet2012"
|
||||||
|
@ -49,6 +51,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
all_reduce_fusion_config:
|
all_reduce_fusion_config:
|
||||||
- 2
|
- 2
|
||||||
- 60
|
- 60
|
||||||
|
|
|
@ -33,6 +33,8 @@ lr_decay_mode: "poly"
|
||||||
lr_init: 0.01
|
lr_init: 0.01
|
||||||
lr_end: 0.00001
|
lr_end: 0.00001
|
||||||
lr_max: 0.1
|
lr_max: 0.1
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
|
|
||||||
net_name: "resnet18"
|
net_name: "resnet18"
|
||||||
dataset: "cifar10"
|
dataset: "cifar10"
|
||||||
|
@ -49,6 +51,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
|
|
||||||
# Export options
|
# Export options
|
||||||
device_id: 0
|
device_id: 0
|
||||||
|
|
|
@ -35,6 +35,8 @@ label_smooth_factor: 0.1
|
||||||
lr_init: 0
|
lr_init: 0
|
||||||
lr_max: 0.8
|
lr_max: 0.8
|
||||||
lr_end: 0.0
|
lr_end: 0.0
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
|
|
||||||
net_name: "resnet18"
|
net_name: "resnet18"
|
||||||
dataset: "imagenet2012"
|
dataset: "imagenet2012"
|
||||||
|
@ -51,6 +53,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
|
|
||||||
# Export options
|
# Export options
|
||||||
device_id: 0
|
device_id: 0
|
||||||
|
|
|
@ -35,6 +35,8 @@ label_smooth_factor: 0.1
|
||||||
lr_init: 0
|
lr_init: 0
|
||||||
lr_max: 0.8
|
lr_max: 0.8
|
||||||
lr_end: 0.0
|
lr_end: 0.0
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
|
|
||||||
net_name: "resnet34"
|
net_name: "resnet34"
|
||||||
dataset: "imagenet2012"
|
dataset: "imagenet2012"
|
||||||
|
@ -51,6 +53,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
|
|
||||||
# Export options
|
# Export options
|
||||||
device_id: 0
|
device_id: 0
|
||||||
|
|
|
@ -33,6 +33,8 @@ lr_decay_mode: "poly"
|
||||||
lr_init: 0.01
|
lr_init: 0.01
|
||||||
lr_end: 0.00001
|
lr_end: 0.00001
|
||||||
lr_max: 0.1
|
lr_max: 0.1
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
|
|
||||||
net_name: "resnet50"
|
net_name: "resnet50"
|
||||||
dataset: "cifar10"
|
dataset: "cifar10"
|
||||||
|
@ -49,6 +51,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
all_reduce_fusion_config:
|
all_reduce_fusion_config:
|
||||||
- 2
|
- 2
|
||||||
- 115
|
- 115
|
||||||
|
|
|
@ -35,6 +35,8 @@ label_smooth_factor: 0.1
|
||||||
lr_init: 0
|
lr_init: 0
|
||||||
lr_max: 0.8
|
lr_max: 0.8
|
||||||
lr_end: 0.0
|
lr_end: 0.0
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
|
|
||||||
net_name: "resnet50"
|
net_name: "resnet50"
|
||||||
dataset: "imagenet2012"
|
dataset: "imagenet2012"
|
||||||
|
@ -51,6 +53,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O1"
|
acc_mode: "O1"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
all_reduce_fusion_config:
|
all_reduce_fusion_config:
|
||||||
- 85
|
- 85
|
||||||
- 160
|
- 160
|
||||||
|
|
|
@ -33,6 +33,8 @@ label_smooth_factor: 0.1
|
||||||
lr_init: 0.05803
|
lr_init: 0.05803
|
||||||
lr_decay: 4.04839
|
lr_decay: 4.04839
|
||||||
lr_end_epoch: 53
|
lr_end_epoch: 53
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
damping_init: 0.02714
|
damping_init: 0.02714
|
||||||
damping_decay: 0.50036
|
damping_decay: 0.50036
|
||||||
frequency: 834
|
frequency: 834
|
||||||
|
@ -52,6 +54,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
all_reduce_fusion_config:
|
all_reduce_fusion_config:
|
||||||
- 85
|
- 85
|
||||||
- 160
|
- 160
|
||||||
|
|
|
@ -33,6 +33,8 @@ label_smooth_factor: 0.1
|
||||||
lr_init: 0.05672
|
lr_init: 0.05672
|
||||||
lr_decay: 4.9687
|
lr_decay: 4.9687
|
||||||
lr_end_epoch: 50
|
lr_end_epoch: 50
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
damping_init: 0.02345
|
damping_init: 0.02345
|
||||||
damping_decay: 0.5467
|
damping_decay: 0.5467
|
||||||
frequency: 834
|
frequency: 834
|
||||||
|
@ -52,6 +54,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
all_reduce_fusion_config:
|
all_reduce_fusion_config:
|
||||||
- 85
|
- 85
|
||||||
- 160
|
- 160
|
||||||
|
|
|
@ -35,6 +35,8 @@ label_smooth_factor: 0.1
|
||||||
lr_init: 0
|
lr_init: 0
|
||||||
lr_max: 0.8
|
lr_max: 0.8
|
||||||
lr_end: 0.0
|
lr_end: 0.0
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
|
|
||||||
net_name: "resnet50"
|
net_name: "resnet50"
|
||||||
dataset: "imagenet2012"
|
dataset: "imagenet2012"
|
||||||
|
@ -51,6 +53,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
all_reduce_fusion_config:
|
all_reduce_fusion_config:
|
||||||
- 85
|
- 85
|
||||||
- 160
|
- 160
|
||||||
|
|
|
@ -26,6 +26,8 @@ save_ckpt: False
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
dtype: "fp16"
|
dtype: "fp16"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
|
|
||||||
# Export options
|
# Export options
|
||||||
device_id: 0
|
device_id: 0
|
||||||
|
|
|
@ -36,6 +36,8 @@ label_smooth_factor: 0.1
|
||||||
lr_init: 0
|
lr_init: 0
|
||||||
lr_end: 0.0001
|
lr_end: 0.0001
|
||||||
lr_max: 0.3
|
lr_max: 0.3
|
||||||
|
lars_epsilon: 0.0
|
||||||
|
lars_coefficient: 0.001
|
||||||
|
|
||||||
net_name: "se-resnet50"
|
net_name: "se-resnet50"
|
||||||
dataset: "imagenet2012"
|
dataset: "imagenet2012"
|
||||||
|
@ -52,6 +54,8 @@ enable_cache: False
|
||||||
cache_session_id: ""
|
cache_session_id: ""
|
||||||
mode_name: "GRAPH"
|
mode_name: "GRAPH"
|
||||||
acc_mode: "O0"
|
acc_mode: "O0"
|
||||||
|
conv_init: "XavierUniform"
|
||||||
|
dense_init: "TruncatedNormal"
|
||||||
all_reduce_fusion_config:
|
all_reduce_fusion_config:
|
||||||
- 1
|
- 1
|
||||||
- 100
|
- 100
|
||||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.ops import functional as F
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
|
def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
|
||||||
fan_in = in_channel * kernel_size * kernel_size
|
fan_in = in_channel * kernel_size * kernel_size
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
scale /= max(1., fan_in)
|
scale /= max(1., fan_in)
|
||||||
|
@ -108,7 +108,7 @@ def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'
|
||||||
|
|
||||||
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
||||||
if use_se:
|
if use_se:
|
||||||
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
|
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
|
||||||
else:
|
else:
|
||||||
weight_shape = (out_channel, in_channel, 3, 3)
|
weight_shape = (out_channel, in_channel, 3, 3)
|
||||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||||
|
@ -121,7 +121,7 @@ def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
||||||
|
|
||||||
def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
||||||
if use_se:
|
if use_se:
|
||||||
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
|
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
|
||||||
else:
|
else:
|
||||||
weight_shape = (out_channel, in_channel, 1, 1)
|
weight_shape = (out_channel, in_channel, 1, 1)
|
||||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||||
|
@ -134,7 +134,7 @@ def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
||||||
|
|
||||||
def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
||||||
if use_se:
|
if use_se:
|
||||||
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
|
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
|
||||||
else:
|
else:
|
||||||
weight_shape = (out_channel, in_channel, 7, 7)
|
weight_shape = (out_channel, in_channel, 7, 7)
|
||||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||||
|
@ -207,7 +207,7 @@ class ResidualBlock(nn.Cell):
|
||||||
self.bn2 = _bn(channel)
|
self.bn2 = _bn(channel)
|
||||||
|
|
||||||
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
|
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
|
||||||
self.bn3 = _bn_last(out_channel)
|
self.bn3 = _bn(out_channel)
|
||||||
if self.se_block:
|
if self.se_block:
|
||||||
self.se_global_pool = P.ReduceMean(keep_dims=False)
|
self.se_global_pool = P.ReduceMean(keep_dims=False)
|
||||||
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
|
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""train resnet."""
|
"""train resnet."""
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.nn.optim import Momentum, thor
|
from mindspore.nn.optim import Momentum, thor, LARS
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.train_thor import ConvertModelUtils
|
from mindspore.train.train_thor import ConvertModelUtils
|
||||||
|
@ -37,6 +38,7 @@ from src.metric import DistAccuracy, ClassifyCorrectCell
|
||||||
from src.model_utils.config import config
|
from src.model_utils.config import config
|
||||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
from src.model_utils.device_adapter import get_rank_id, get_device_num
|
from src.model_utils.device_adapter import get_rank_id, get_device_num
|
||||||
|
from src.resnet import conv_variance_scaling_initializer
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
|
@ -130,13 +132,26 @@ def init_weight(net):
|
||||||
else:
|
else:
|
||||||
for _, cell in net.cells_and_names():
|
for _, cell in net.cells_and_names():
|
||||||
if isinstance(cell, nn.Conv2d):
|
if isinstance(cell, nn.Conv2d):
|
||||||
|
if config.conv_init == "XavierUniform":
|
||||||
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
|
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
|
||||||
cell.weight.shape,
|
cell.weight.shape,
|
||||||
cell.weight.dtype))
|
cell.weight.dtype))
|
||||||
|
elif config.conv_init == "TruncatedNormal":
|
||||||
|
weight = conv_variance_scaling_initializer(cell.in_channels,
|
||||||
|
cell.out_channels,
|
||||||
|
cell.kernel_size[0])
|
||||||
|
cell.weight.set_data(weight)
|
||||||
if isinstance(cell, nn.Dense):
|
if isinstance(cell, nn.Dense):
|
||||||
|
if config.dense_init == "TruncatedNormal":
|
||||||
cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
|
cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
|
||||||
cell.weight.shape,
|
cell.weight.shape,
|
||||||
cell.weight.dtype))
|
cell.weight.dtype))
|
||||||
|
elif config.dense_init == "RandomNormal":
|
||||||
|
in_channel = cell.in_channels
|
||||||
|
out_channel = cell.out_channels
|
||||||
|
weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
|
||||||
|
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=cell.weight.dtype)
|
||||||
|
cell.weight.set_data(weight)
|
||||||
|
|
||||||
def init_lr(step_size):
|
def init_lr(step_size):
|
||||||
"""init lr"""
|
"""init lr"""
|
||||||
|
@ -163,6 +178,21 @@ def init_loss_scale():
|
||||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def init_group_params(net):
    """Split the trainable parameters of ``net`` into optimizer groups.

    Parameters whose name contains 'beta', 'gamma' or 'bias' are placed in a
    group that carries no explicit weight decay; every other parameter is
    placed in a group that uses ``config.weight_decay``. A third entry holds
    ``net.trainable_params()`` under the 'order_params' key.

    Args:
        net: object exposing ``trainable_params()``, whose items have a
            ``name`` attribute.

    Returns:
        list[dict], the decayed group, the non-decayed group and the
        'order_params' entry, in that order.
    """
    excluded_tags = ('beta', 'gamma', 'bias')
    with_decay, without_decay = [], []
    for weight in net.trainable_params():
        # A parameter is exempt from weight decay as soon as one tag matches.
        is_excluded = any(tag in weight.name for tag in excluded_tags)
        (without_decay if is_excluded else with_decay).append(weight)

    return [
        {'params': with_decay, 'weight_decay': config.weight_decay},
        {'params': without_decay},
        {'order_params': net.trainable_params()},
    ]
|
||||||
|
|
||||||
def run_eval(target, model, ckpt_save_dir, cb):
|
def run_eval(target, model, ckpt_save_dir, cb):
|
||||||
"""run_eval"""
|
"""run_eval"""
|
||||||
if config.run_eval:
|
if config.run_eval:
|
||||||
|
@ -205,18 +235,11 @@ def train_net():
|
||||||
init_weight(net=net)
|
init_weight(net=net)
|
||||||
lr = Tensor(init_lr(step_size=step_size))
|
lr = Tensor(init_lr(step_size=step_size))
|
||||||
# define opt
|
# define opt
|
||||||
decayed_params = []
|
group_params = init_group_params(net)
|
||||||
no_decayed_params = []
|
|
||||||
for param in net.trainable_params():
|
|
||||||
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
|
|
||||||
decayed_params.append(param)
|
|
||||||
else:
|
|
||||||
no_decayed_params.append(param)
|
|
||||||
|
|
||||||
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
|
|
||||||
{'params': no_decayed_params},
|
|
||||||
{'order_params': net.trainable_params()}]
|
|
||||||
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
|
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
|
||||||
|
if config.optimizer == "LARS":
|
||||||
|
opt = LARS(opt, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
|
||||||
|
lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)
|
||||||
loss = init_loss_scale()
|
loss = init_loss_scale()
|
||||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||||
dist_eval_network = ClassifyCorrectCell(net) if config.run_distribute else None
|
dist_eval_network = ClassifyCorrectCell(net) if config.run_distribute else None
|
||||||
|
|
Loading…
Reference in New Issue