fix model_zoo

This commit is contained in:
panfengfeng 2020-07-21 11:33:42 +08:00
parent 25e587e483
commit a45e29800c
6 changed files with 5 additions and 6 deletions

View File

@ -85,4 +85,3 @@ if __name__ == "__main__":
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
model = Model(net, loss, opt)
model.train(config.epoch_size, train_dataset, callback)

View File

@ -46,7 +46,7 @@ args = parser.parse_args()
if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
step_size = ds_train.get_dataset_size()
# define fusion network

View File

@ -105,7 +105,7 @@ if __name__ == '__main__':
# define dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
repeat_num=epoch_size,
repeat_num=1,
batch_size=config.batch_size,
target=args_opt.device_target)
step_size = dataset.get_dataset_size()

View File

@ -191,7 +191,7 @@ def train(cloud_args=None):
# dataloader
de_dataset = classification_dataset(args.data_dir, args.image_size,
args.per_batch_size, args.max_epoch,
args.per_batch_size, 1,
args.rank, args.group_size)
de_dataset.map_model = 4 # !!!important
args.steps_per_epoch = de_dataset.get_dataset_size()

View File

@ -59,7 +59,7 @@ if __name__ == '__main__':
max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size)
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
step_size = dataset.get_dataset_size()
# define lr
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num

View File

@ -290,7 +290,7 @@ def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix,
writer.commit()
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=1, device_num=1, rank=0,
is_training=True, num_parallel_workers=8):
"""Creatr YOLOv3 dataset with MindDataset."""
ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,