forked from mindspore-Ecosystem/mindspore
!14962 clean pylint in modelzoo
From: @zhao_ting_v Reviewed-by: @oacjiewen,@wuxuejian Signed-off-by: @wuxuejian
This commit is contained in:
commit
cac91018ad
|
@ -102,7 +102,7 @@ def main(args):
|
|||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
cfg.logger.info('load model %s success' % str(cfg.pretrained))
|
||||
cfg.logger.info('load model %s success.' % cfg.pretrained)
|
||||
|
||||
# optimizer and lr scheduler
|
||||
lr = warmup_step(cfg, gamma=0.9)
|
||||
|
|
|
@ -328,6 +328,6 @@ if __name__ == '__main__':
|
|||
log_path = os.path.join(arg.ckpt_path, 'logs')
|
||||
arg.logger = get_logger(log_path, arg.local_rank)
|
||||
|
||||
arg.logger.info('Config\n\n%s\n' % str(pformat(arg)))
|
||||
arg.logger.info('Config\n\n{}\n'.format(pformat(arg)))
|
||||
|
||||
main(arg)
|
||||
|
|
|
@ -44,8 +44,8 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs
|
|||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
|
||||
|
||||
def main():
|
||||
def init_argument():
|
||||
"""init config argument."""
|
||||
parser = argparse.ArgumentParser(description='Cifar10 classification')
|
||||
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
|
||||
parser.add_argument('--data_dir', type=str, default='', help='image label list file, e.g. /home/label.txt')
|
||||
|
@ -78,23 +78,25 @@ def main():
|
|||
# logger
|
||||
cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
|
||||
loss_meter = AverageMeter('loss')
|
||||
|
||||
# Show cfg
|
||||
cfg.logger.save_args(cfg)
|
||||
return cfg
|
||||
|
||||
def main():
|
||||
cfg = init_argument()
|
||||
loss_meter = AverageMeter('loss')
|
||||
# dataloader
|
||||
cfg.logger.info('start create dataloader')
|
||||
de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
|
||||
cfg.steps_per_epoch = steps_per_epoch
|
||||
cfg.logger.info('step per epoch: ' + str(cfg.steps_per_epoch))
|
||||
cfg.logger.info('step per epoch: %d' % cfg.steps_per_epoch)
|
||||
de_dataloader = de_dataset.create_tuple_iterator()
|
||||
|
||||
cfg.logger.info('class num original: ' + str(class_num))
|
||||
cfg.logger.info('class num original: %d' % class_num)
|
||||
if class_num % 16 != 0:
|
||||
class_num = (class_num // 16 + 1) * 16
|
||||
cfg.class_num = class_num
|
||||
cfg.logger.info('change the class num to :' + str(cfg.class_num))
|
||||
cfg.logger.info('change the class num to: %d' % cfg.class_num)
|
||||
cfg.logger.info('end create dataloader')
|
||||
|
||||
# backbone and loss
|
||||
|
@ -117,7 +119,7 @@ def main():
|
|||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
cfg.logger.info('load model {} success'.format(cfg.pretrained))
|
||||
cfg.logger.info('load model %s success' % cfg.pretrained)
|
||||
|
||||
# mixed precision training
|
||||
network.add_flags_recursive(fp16=True)
|
||||
|
|
Loading…
Reference in New Issue