From a45e29800c6d824aedb2378bfb486ca40f34836c Mon Sep 17 00:00:00 2001 From: panfengfeng Date: Tue, 21 Jul 2020 11:33:42 +0800 Subject: [PATCH] fix model_zoo --- model_zoo/deeplabv3/train.py | 1 - model_zoo/lenet_quant/train_quant.py | 2 +- model_zoo/resnet50_quant/train.py | 2 +- model_zoo/resnext50/train.py | 2 +- model_zoo/warpctc/train.py | 2 +- model_zoo/yolov3_resnet18/src/dataset.py | 2 +- 6 files changed, 5 insertions(+), 6 deletions(-) diff --git a/model_zoo/deeplabv3/train.py b/model_zoo/deeplabv3/train.py index 39d50e51ccd..56ef5b02bb0 100644 --- a/model_zoo/deeplabv3/train.py +++ b/model_zoo/deeplabv3/train.py @@ -85,4 +85,3 @@ if __name__ == "__main__": opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) model = Model(net, loss, opt) model.train(config.epoch_size, train_dataset, callback) - \ No newline at end of file diff --git a/model_zoo/lenet_quant/train_quant.py b/model_zoo/lenet_quant/train_quant.py index 3a87ccc70d0..33c322f4b52 100644 --- a/model_zoo/lenet_quant/train_quant.py +++ b/model_zoo/lenet_quant/train_quant.py @@ -46,7 +46,7 @@ args = parser.parse_args() if __name__ == "__main__": context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) + ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1) step_size = ds_train.get_dataset_size() # define fusion network diff --git a/model_zoo/resnet50_quant/train.py b/model_zoo/resnet50_quant/train.py index b026f972788..3d0656b80d7 100755 --- a/model_zoo/resnet50_quant/train.py +++ b/model_zoo/resnet50_quant/train.py @@ -105,7 +105,7 @@ if __name__ == '__main__': # define dataset dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, - repeat_num=epoch_size, + repeat_num=1, batch_size=config.batch_size, target=args_opt.device_target) step_size = dataset.get_dataset_size() diff --git a/model_zoo/resnext50/train.py b/model_zoo/resnext50/train.py index 29ccd9b00c6..ec2e33aba3a 100644 --- a/model_zoo/resnext50/train.py +++ b/model_zoo/resnext50/train.py @@ -191,7 +191,7 @@ def train(cloud_args=None): # dataloader de_dataset = classification_dataset(args.data_dir, args.image_size, - args.per_batch_size, args.max_epoch, + args.per_batch_size, 1, args.rank, args.group_size) de_dataset.map_model = 4 # !!!important args.steps_per_epoch = de_dataset.get_dataset_size() diff --git a/model_zoo/warpctc/train.py b/model_zoo/warpctc/train.py index 651d2a73a4d..8b5171f70a7 100755 --- a/model_zoo/warpctc/train.py +++ b/model_zoo/warpctc/train.py @@ -59,7 +59,7 @@ if __name__ == '__main__': max_captcha_digits = cf.max_captcha_digits input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 # create dataset - dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size) + dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) step_size = dataset.get_dataset_size() # define lr lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num diff --git a/model_zoo/yolov3_resnet18/src/dataset.py b/model_zoo/yolov3_resnet18/src/dataset.py index f85b442209c..7c5177a3fef 100644 --- a/model_zoo/yolov3_resnet18/src/dataset.py +++ b/model_zoo/yolov3_resnet18/src/dataset.py @@ -290,7 +290,7 @@ def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix, writer.commit() -def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0, +def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=1, device_num=1, rank=0, is_training=True, num_parallel_workers=8): """Creatr YOLOv3 dataset with MindDataset.""" ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,