forked from OSSInnovation/mindspore
!3261 fix model zoo of get daatset size
Merge pull request !3261 from panfengfeng/fix_model_zoo_of_get_dataset_size
This commit is contained in:
commit
2d1ad06439
|
@ -85,4 +85,3 @@ if __name__ == "__main__":
|
|||
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
|
||||
model = Model(net, loss, opt)
|
||||
model.train(config.epoch_size, train_dataset, callback)
|
||||
|
|
@ -46,7 +46,7 @@ args = parser.parse_args()
|
|||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
|
||||
step_size = ds_train.get_dataset_size()
|
||||
|
||||
# define fusion network
|
||||
|
|
|
@ -105,7 +105,7 @@ if __name__ == '__main__':
|
|||
# define dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
do_train=True,
|
||||
repeat_num=epoch_size,
|
||||
repeat_num=1,
|
||||
batch_size=config.batch_size,
|
||||
target=args_opt.device_target)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
|
|
@ -191,7 +191,7 @@ def train(cloud_args=None):
|
|||
|
||||
# dataloader
|
||||
de_dataset = classification_dataset(args.data_dir, args.image_size,
|
||||
args.per_batch_size, args.max_epoch,
|
||||
args.per_batch_size, 1,
|
||||
args.rank, args.group_size)
|
||||
de_dataset.map_model = 4 # !!!important
|
||||
args.steps_per_epoch = de_dataset.get_dataset_size()
|
||||
|
|
|
@ -59,7 +59,7 @@ if __name__ == '__main__':
|
|||
max_captcha_digits = cf.max_captcha_digits
|
||||
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size)
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# define lr
|
||||
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
|
||||
|
|
|
@ -290,7 +290,7 @@ def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix,
|
|||
writer.commit()
|
||||
|
||||
|
||||
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
|
||||
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=1, device_num=1, rank=0,
|
||||
is_training=True, num_parallel_workers=8):
|
||||
"""Creatr YOLOv3 dataset with MindDataset."""
|
||||
ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
|
||||
|
|
Loading…
Reference in New Issue