forked from mindspore-Ecosystem/mindspore
standardization of mobilenetv2 and resnet50 quant network
This commit is contained in:
parent c7f461acc7
commit 56806d2297
@@ -43,7 +43,6 @@ run_ascend()
     --training_script=${BASEPATH}/../train.py \
     --dataset_path=$5 \
     --pre_trained=$6 \
-    --quantization_aware=True \
     --device_target=$1 &> train.log & # dataset train folder
 }
 
@@ -75,8 +74,7 @@ run_gpu()
     python ${BASEPATH}/../train.py \
     --dataset_path=$4 \
     --device_target=$1 \
-    --pre_trained=$5 \
-    --quantization_aware=True &> ../train.log & # dataset train folder
+    --pre_trained=$5 &> ../train.log & # dataset train folder
 }
 
 if [ $# -gt 6 ] || [ $# -lt 5 ]
@@ -16,34 +16,12 @@
 network config setting, will be used in train.py and eval.py
 """
 from easydict import EasyDict as ed
 
-config_ascend = ed({
-    "num_classes": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "batch_size": 256,
-    "data_load_mode": "mindrecord",
-    "epoch_size": 200,
-    "start_epoch": 0,
-    "warmup_epochs": 4,
-    "lr": 0.4,
-    "momentum": 0.9,
-    "weight_decay": 4e-5,
-    "label_smooth": 0.1,
-    "loss_scale": 1024,
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 300,
-    "save_checkpoint_path": "./checkpoint",
-    "quantization_aware": False,
-})
-
 config_ascend_quant = ed({
     "num_classes": 1000,
     "image_height": 224,
     "image_width": 224,
     "batch_size": 192,
-    "data_load_mode": "mindrecord",
+    "data_load_mode": "mindata",
     "epoch_size": 60,
     "start_epoch": 200,
     "warmup_epochs": 1,
@@ -59,24 +37,6 @@ config_ascend_quant = ed({
     "quantization_aware": True,
 })
 
-config_gpu = ed({
-    "num_classes": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "batch_size": 150,
-    "epoch_size": 200,
-    "warmup_epochs": 4,
-    "lr": 0.8,
-    "momentum": 0.9,
-    "weight_decay": 4e-5,
-    "label_smooth": 0.1,
-    "loss_scale": 1024,
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 300,
-    "save_checkpoint_path": "./checkpoint",
-})
-
 config_gpu_quant = ed({
     "num_classes": 1000,
     "image_height": 224,
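For reference, the remaining dictionaries above are plain EasyDict objects, so the training scripts read their keys by attribute access. A minimal, self-contained sketch (values copied from config_ascend_quant above):

from easydict import EasyDict as ed

# Minimal sketch of how the configs above are consumed: EasyDict exposes the keys
# as attributes, so train.py can write config.batch_size instead of config["batch_size"].
config = ed({"batch_size": 192, "epoch_size": 60, "quantization_aware": True})
print(config.batch_size)           # -> 192
print(config.quantization_aware)   # -> True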
@@ -26,6 +26,38 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
 
 
+def _load_param_into_net(model, params_dict):
+    """
+    load fp32 model parameters to quantization model.
+
+    Args:
+        model: quantization model
+        params_dict: f32 param
+
+    Returns:
+        None
+    """
+    iterable_dict = {
+        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
+        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
+        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
+        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
+        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
+        'moving_variance': iter(
+            [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
+        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
+        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
+    }
+    for name, param in model.parameters_and_names():
+        key_name = name.split(".")[-1]
+        if key_name not in iterable_dict.keys():
+            raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
+        value_param = next(iterable_dict[key_name], None)
+        if value_param is not None:
+            param.set_parameter_data(value_param[1].data)
+            print(f'init model param {name} with checkpoint param {value_param[0]}')
+
+
 class Monitor(Callback):
     """
     Monitor loss and time.
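A minimal usage sketch of the new helper; the checkpoint path and constructor arguments are hypothetical, and the call pattern mirrors the train.py changes below:

from mindspore.train.serialization import load_checkpoint
from src.mobilenetV2 import mobilenetV2
from src.utils import _load_param_into_net

# Hypothetical example: restore an fp32 MobileNetV2 checkpoint into the fused network
# by suffix matching (weight, bias, gamma, beta, moving_mean, moving_variance, minq, maxq).
network = mobilenetV2(num_classes=1000)                 # assumed constructor signature
param_dict = load_checkpoint("mobilenetv2_fp32.ckpt")   # hypothetical checkpoint path
_load_param_into_net(network, param_dict)               # prints every matched parameter pair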
@@ -25,7 +25,7 @@ from mindspore import nn
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.serialization import load_checkpoint
 from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.quant import quant
 import mindspore.dataset.engine as de
@@ -33,8 +33,9 @@ import mindspore.dataset.engine as de
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.utils import Monitor, CrossEntropyWithLabelSmooth
-from src.config import config_ascend_quant, config_ascend, config_gpu_quant, config_gpu
+from src.config import config_ascend_quant, config_gpu_quant
 from src.mobilenetV2 import mobilenetV2
+from src.utils import _load_param_into_net
 
 random.seed(1)
 np.random.seed(1)
@@ -44,7 +45,6 @@ parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
 parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
 parser.add_argument('--device_target', type=str, default=None, help='Run device target')
-parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training')
 args_opt = parser.parse_args()
 
 if args_opt.device_target == "Ascend":
@@ -69,7 +69,7 @@ else:
 
 
 def train_on_ascend():
-    config = config_ascend_quant if args_opt.quantization_aware else config_ascend
+    config = config_ascend_quant
     print("training args: {}".format(args_opt))
     print("training configure: {}".format(config))
     print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
@@ -101,14 +101,12 @@ def train_on_ascend():
     # load pre trained ckpt
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        load_param_into_net(network, param_dict)
-
+        _load_param_into_net(network, param_dict)
     # convert fusion network to quantization aware network
-    if config.quantization_aware:
-        network = quant.convert_quant_network(network,
-                                              bn_fold=True,
-                                              per_channel=[True, False],
-                                              symmetric=[True, False])
+    network = quant.convert_quant_network(network,
+                                          bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, False])
 
     # get learning rate
     lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
@@ -141,7 +139,7 @@ def train_on_ascend():
 
 
 def train_on_gpu():
-    config = config_gpu_quant if args_opt.quantization_aware else config_gpu
+    config = config_gpu_quant
     print("training args: {}".format(args_opt))
     print("training configure: {}".format(config))
 
@@ -165,14 +163,15 @@ def train_on_gpu():
     # resume
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        load_param_into_net(network, param_dict)
+        _load_param_into_net(network, param_dict)
 
     # convert fusion network to quantization aware network
-    if config.quantization_aware:
-        network = quant.convert_quant_network(network,
-                                              bn_fold=True,
-                                              per_channel=[True, False],
-                                              symmetric=[True, True])
+    network = quant.convert_quant_network(network,
+                                          bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, True],
+                                          freeze_bn=1000000,
+                                          quant_delay=step_size * 2)
 
     # get learning rate
     loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
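For orientation, a short sketch of the GPU-path conversion above; `network` and `dataset` stand for objects train_on_gpu() has already built, and the comments give an assumed reading of freeze_bn and quant_delay, not something stated in the diff:

from mindspore.train.quant import quant

# Sketch only: `dataset` and `network` are the objects train_on_gpu() has already created.
step_size = dataset.get_dataset_size()   # batches per epoch
network = quant.convert_quant_network(network,
                                      bn_fold=True,                # fold BatchNorm into the preceding conv
                                      per_channel=[True, False],   # per-channel quant for weights, per-tensor for activations (assumed reading)
                                      symmetric=[True, True],
                                      freeze_bn=1000000,           # BN statistics effectively never frozen within this run (assumed reading)
                                      quant_delay=step_size * 2)   # start simulated quantization after roughly two epochs (assumed reading)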
@@ -16,33 +16,6 @@
 network config setting, will be used in train.py and eval.py
 """
 from easydict import EasyDict as ed
 
-quant_set = ed({
-    "quantization_aware": True,
-})
-config_noquant = ed({
-    "class_num": 1001,
-    "batch_size": 32,
-    "loss_scale": 1024,
-    "momentum": 0.9,
-    "weight_decay": 1e-4,
-    "epoch_size": 90,
-    "pretrained_epoch_size": 1,
-    "buffer_size": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "data_load_mode": "mindrecord",
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 50,
-    "save_checkpoint_path": "./",
-    "warmup_epochs": 0,
-    "lr_decay_mode": "cosine",
-    "use_label_smooth": True,
-    "label_smooth_factor": 0.1,
-    "lr_init": 0,
-    "lr_max": 0.1,
-})
 config_quant = ed({
     "class_num": 1001,
     "batch_size": 32,
@@ -54,7 +27,7 @@ config_quant = ed({
     "buffer_size": 1000,
     "image_height": 224,
     "image_width": 224,
-    "data_load_mode": "mindrecord",
+    "data_load_mode": "mindata",
     "save_checkpoint": True,
     "save_checkpoint_epochs": 1,
     "keep_checkpoint_max": 50,
@@ -33,7 +33,7 @@ import mindspore.common.initializer as weight_init
 from models.resnet_quant import resnet50_quant
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
-from src.config import quant_set, config_quant, config_noquant
+from src.config import config_quant
 from src.crossentropy import CrossEntropy
 from src.utils import _load_param_into_net
 
@@ -44,7 +44,7 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
 parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
 parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
 args_opt = parser.parse_args()
-config = config_quant if quant_set.quantization_aware else config_noquant
+config = config_quant
 
 if args_opt.device_target == "Ascend":
     device_id = int(os.getenv('DEVICE_ID'))
@@ -110,9 +110,8 @@ if __name__ == '__main__':
                              target=args_opt.device_target)
     step_size = dataset.get_dataset_size()
 
-    if quant_set.quantization_aware:
-        # convert fusion network to quantization aware network
-        net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
+    # convert fusion network to quantization aware network
+    net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
 
     # get learning rate
     lr = get_lr(lr_init=config.lr_init,
@@ -131,11 +130,7 @@ if __name__ == '__main__':
                    config.weight_decay, config.loss_scale)
 
     # define model
-    if quant_set.quantization_aware:
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
-    else:
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
-                      amp_level="O2")
+    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
 
     print("============== Starting Training ==============")
     time_callback = TimeMonitor(data_size=step_size)