forked from OSSInnovation/mindspore
modify resnet101 scripts for pylint
This commit is contained in:
parent
99bbb3a3b2
commit
3cb692bea1
|
@ -54,9 +54,6 @@ Parameters for both training and evaluating can be set in config.py.
|
|||
"save_checkpoint_steps": 500, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
|
||||
"keep_checkpoint_max": 40, # only keep the last keep_checkpoint_max checkpoint
|
||||
"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path
|
||||
"lr_init": 0.01, # initial learning rate
|
||||
"lr_end": 0.00001, # final learning rate
|
||||
"lr_max": 0.1, # maximum learning rate
|
||||
"warmup_epochs": 0, # number of warmup epoch
|
||||
"lr_decay_mode": "cosine" # decay mode for generating learning rate
|
||||
"label_smooth": 1, # label_smooth
|
||||
|
|
|
@ -31,9 +31,6 @@ config = ed({
|
|||
"save_checkpoint_steps": 500,
|
||||
"keep_checkpoint_max": 40,
|
||||
"save_checkpoint_path": "./",
|
||||
"lr_init": 0.01,
|
||||
"lr_end": 0.00001,
|
||||
"lr_max": 0.1,
|
||||
"warmup_epochs": 0,
|
||||
"lr_decay_mode": "cosine",
|
||||
"label_smooth": 1,
|
||||
|
|
|
@ -50,63 +50,3 @@ def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch):
|
|||
lr = base_lr * decayed
|
||||
lr_each_step.append(lr)
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
global_step(int): total steps of the training
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_epochs(int): number of warmup epochs
|
||||
total_epochs(int): total epoch of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
lr_decay_mode(string): learning rate decay mode, including steps, poly or default
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
if lr_decay_mode == 'steps':
|
||||
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
|
||||
for i in range(total_steps):
|
||||
if i < decay_epoch_index[0]:
|
||||
lr = lr_max
|
||||
elif i < decay_epoch_index[1]:
|
||||
lr = lr_max * 0.1
|
||||
elif i < decay_epoch_index[2]:
|
||||
lr = lr_max * 0.01
|
||||
else:
|
||||
lr = lr_max * 0.001
|
||||
lr_each_step.append(lr)
|
||||
elif lr_decay_mode == 'poly':
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max) * base * base
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
else:
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
else:
|
||||
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
|
||||
lr_each_step.append(lr)
|
||||
|
||||
current_step = global_step
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
|
||||
return learning_rate
|
||||
|
|
|
@ -19,7 +19,7 @@ import argparse
|
|||
import random
|
||||
import numpy as np
|
||||
from dataset import create_dataset
|
||||
from lr_generator import get_lr, warmup_cosine_annealing_lr
|
||||
from lr_generator import warmup_cosine_annealing_lr
|
||||
from config import config
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
|
@ -32,9 +32,9 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
|||
import mindspore.dataset.engine as de
|
||||
from mindspore.communication.management import init
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
from crossentropy import CrossEntropy
|
||||
from var_init import default_recurisive_init, KaimingNormal
|
||||
import mindspore.common.initializer as weight_init
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
|
@ -72,7 +72,7 @@ if __name__ == '__main__':
|
|||
net = resnet101(class_num=config.class_num)
|
||||
# weight init
|
||||
default_recurisive_init(net)
|
||||
for name, cell in net.cells_and_names():
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.default_input = weight_init.initializer(KaimingNormal(a=math.sqrt(5),
|
||||
mode='fan_out', nonlinearity='relu'),
|
||||
|
@ -83,17 +83,12 @@ if __name__ == '__main__':
|
|||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
if args_opt.do_train:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
|
||||
repeat_num=epoch_size, batch_size=config.batch_size)
|
||||
repeat_num=epoch_size, batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
|
||||
# learning rate strategy
|
||||
if config.lr_decay_mode == 'cosine':
|
||||
lr = Tensor(warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size))
|
||||
else:
|
||||
lr = Tensor(get_lr(global_step=0, lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
|
||||
warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,
|
||||
lr_decay_mode='poly'))
|
||||
# learning rate strategy with cosine
|
||||
lr = Tensor(warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size))
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
||||
config.weight_decay, config.loss_scale)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', keep_batchnorm_fp32=False,
|
||||
|
|
|
@ -18,10 +18,10 @@ import numpy as np
|
|||
from mindspore.common import initializer as init
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
r"""Return the recommended gain value for the given nonlinearity function.
|
||||
The values are as follows:
|
||||
The values are as follows:
|
||||
================= ====================================================
|
||||
nonlinearity gain
|
||||
================= ====================================================
|
||||
|
@ -37,12 +37,13 @@ def calculate_gain(nonlinearity, param=None):
|
|||
param: optional parameter for the non-linear function
|
||||
"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
gain = 0
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
return 1
|
||||
gain = 1
|
||||
elif nonlinearity == 'tanh':
|
||||
return 5.0 / 3
|
||||
gain = 5.0 / 3
|
||||
elif nonlinearity == 'relu':
|
||||
return math.sqrt(2.0)
|
||||
gain = math.sqrt(2.0)
|
||||
elif nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
|
@ -51,15 +52,16 @@ def calculate_gain(nonlinearity, param=None):
|
|||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
return math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
gain = math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
|
||||
return gain
|
||||
|
||||
def _calculate_correct_fan(array, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
@ -83,13 +85,12 @@ def kaiming_uniform_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
|||
backwards pass.
|
||||
nonlinearity: the non-linear function (`nn.functional` name),
|
||||
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
|
||||
"""
|
||||
"""
|
||||
fan = _calculate_correct_fan(array, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
return np.random.uniform(-bound, bound, array.shape)
|
||||
|
||||
|
||||
def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
r"""Fills the input `Tensor` with values according to the method
|
||||
|
@ -97,12 +98,10 @@ def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
|||
performance on ImageNet classification` - He, K. et al. (2015), using a
|
||||
normal distribution. The resulting tensor will have values sampled from
|
||||
:math:`\mathcal{N}(0, \text{std}^2)` where
|
||||
|
||||
.. math::
|
||||
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
|
||||
|
||||
Also known as He initialization.
|
||||
|
||||
|
||||
Args:
|
||||
array: an n-dimensional `tensor`
|
||||
a: the negative slope of the rectifier used after this layer (only
|
||||
|
@ -118,13 +117,12 @@ def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
|||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
return np.random.normal(0, std, array.shape)
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(array):
|
||||
"""calculate the fan_in and fan_out for input array"""
|
||||
dimensions = len(array.shape)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
|
||||
|
||||
num_input_fmaps = array.shape[1]
|
||||
num_output_fmaps = array.shape[0]
|
||||
receptive_field_size = 1
|
||||
|
@ -132,19 +130,30 @@ def _calculate_fan_in_and_fan_out(array):
|
|||
receptive_field_size = array[0][0].size
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def assignment(arr, num):
|
||||
"""Assign the value of num to arr"""
|
||||
if arr.shape == ():
|
||||
arr = arr.reshape((1))
|
||||
arr[:] = num
|
||||
arr = arr.reshape(())
|
||||
else:
|
||||
if isinstance(num, np.ndarray):
|
||||
arr[:] = num[:]
|
||||
else:
|
||||
arr[:] = num
|
||||
return arr
|
||||
|
||||
class KaimingUniform(init.Initializer):
|
||||
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
super(KaimingUniform, self).__init__()
|
||||
self.a = a
|
||||
self.mode = mode
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def _initialize(self, arr):
|
||||
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
|
||||
init._assignment(arr, tmp)
|
||||
assignment(arr, tmp)
|
||||
|
||||
class KaimingNormal(init.Initializer):
|
||||
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
|
@ -152,33 +161,32 @@ class KaimingNormal(init.Initializer):
|
|||
self.a = a
|
||||
self.mode = mode
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def _initialize(self, arr):
|
||||
tmp = kaiming_normal_(arr, self.a, self.mode, self.nonlinearity)
|
||||
init._assignment(arr, tmp)
|
||||
assignment(arr, tmp)
|
||||
|
||||
def default_recurisive_init(custom_cell):
|
||||
"""weight init for conv2d and dense"""
|
||||
for name, cell in custom_cell.cells_and_names():
|
||||
for _, cell in custom_cell.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype())
|
||||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype())
|
||||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
|
||||
cell.bias.default_input.shape()),
|
||||
cell.bias.default_input.dtype())
|
||||
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
|
||||
cell.bias.default_input.shape()),
|
||||
cell.bias.default_input.dtype())
|
||||
elif isinstance(cell, nn.Dense):
|
||||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype())
|
||||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype())
|
||||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
|
||||
cell.bias.default_input.shape()),
|
||||
cell.bias.default_input.dtype())
|
||||
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
|
||||
cell.bias.default_input.shape()),
|
||||
cell.bias.default_input.dtype())
|
||||
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
pass
|
||||
|
|
|
@ -279,4 +279,4 @@ def resnet101(class_num=1001):
|
|||
[64, 256, 512, 1024],
|
||||
[256, 512, 1024, 2048],
|
||||
[1, 2, 2, 2],
|
||||
class_num)
|
||||
class_num)
|
||||
|
|
Loading…
Reference in New Issue