!14962 clean pylint in modelzoo

From: @zhao_ting_v
Reviewed-by: @oacjiewen,@wuxuejian
Signed-off-by: @wuxuejian
This commit is contained in:
mindspore-ci-bot 2021-04-13 16:46:57 +08:00 committed by Gitee
commit cac91018ad
3 changed files with 12 additions and 10 deletions

View File

@ -102,7 +102,7 @@ def main(args):
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
cfg.logger.info('load model %s success' % str(cfg.pretrained))
cfg.logger.info('load model %s success.' % cfg.pretrained)
# optimizer and lr scheduler
lr = warmup_step(cfg, gamma=0.9)

View File

@ -328,6 +328,6 @@ if __name__ == '__main__':
log_path = os.path.join(arg.ckpt_path, 'logs')
arg.logger = get_logger(log_path, arg.local_rank)
arg.logger.info('Config\n\n%s\n' % str(pformat(arg)))
arg.logger.info('Config\n\n{}\n'.format(pformat(arg)))
main(arg)

View File

@ -44,8 +44,8 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs
random.seed(1)
np.random.seed(1)
def main():
def init_argument():
"""init config argument."""
parser = argparse.ArgumentParser(description='Cifar10 classification')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--data_dir', type=str, default='', help='image label list file, e.g. /home/label.txt')
@ -78,23 +78,25 @@ def main():
# logger
cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
loss_meter = AverageMeter('loss')
# Show cfg
cfg.logger.save_args(cfg)
return cfg
def main():
cfg = init_argument()
loss_meter = AverageMeter('loss')
# dataloader
cfg.logger.info('start create dataloader')
de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
cfg.steps_per_epoch = steps_per_epoch
cfg.logger.info('step per epoch: ' + str(cfg.steps_per_epoch))
cfg.logger.info('step per epoch: %d' % cfg.steps_per_epoch)
de_dataloader = de_dataset.create_tuple_iterator()
cfg.logger.info('class num original: ' + str(class_num))
cfg.logger.info('class num original: %d' % class_num)
if class_num % 16 != 0:
class_num = (class_num // 16 + 1) * 16
cfg.class_num = class_num
cfg.logger.info('change the class num to :' + str(cfg.class_num))
cfg.logger.info('change the class num to: %d' % cfg.class_num)
cfg.logger.info('end create dataloader')
# backbone and loss
@ -117,7 +119,7 @@ def main():
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
cfg.logger.info('load model {} success'.format(cfg.pretrained))
cfg.logger.info('load model %s success' % cfg.pretrained)
# mixed precision training
network.add_flags_recursive(fp16=True)