forked from mindspore-Ecosystem/mindspore
fix model_zoo
This commit is contained in:
parent
25e587e483
commit
a45e29800c
|
@ -85,4 +85,3 @@ if __name__ == "__main__":
|
||||||
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
|
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
|
||||||
model = Model(net, loss, opt)
|
model = Model(net, loss, opt)
|
||||||
model.train(config.epoch_size, train_dataset, callback)
|
model.train(config.epoch_size, train_dataset, callback)
|
||||||
|
|
|
@ -46,7 +46,7 @@ args = parser.parse_args()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
|
||||||
step_size = ds_train.get_dataset_size()
|
step_size = ds_train.get_dataset_size()
|
||||||
|
|
||||||
# define fusion network
|
# define fusion network
|
||||||
|
|
|
@ -105,7 +105,7 @@ if __name__ == '__main__':
|
||||||
# define dataset
|
# define dataset
|
||||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
repeat_num=epoch_size,
|
repeat_num=1,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
target=args_opt.device_target)
|
target=args_opt.device_target)
|
||||||
step_size = dataset.get_dataset_size()
|
step_size = dataset.get_dataset_size()
|
||||||
|
|
|
@ -191,7 +191,7 @@ def train(cloud_args=None):
|
||||||
|
|
||||||
# dataloader
|
# dataloader
|
||||||
de_dataset = classification_dataset(args.data_dir, args.image_size,
|
de_dataset = classification_dataset(args.data_dir, args.image_size,
|
||||||
args.per_batch_size, args.max_epoch,
|
args.per_batch_size, 1,
|
||||||
args.rank, args.group_size)
|
args.rank, args.group_size)
|
||||||
de_dataset.map_model = 4 # !!!important
|
de_dataset.map_model = 4 # !!!important
|
||||||
args.steps_per_epoch = de_dataset.get_dataset_size()
|
args.steps_per_epoch = de_dataset.get_dataset_size()
|
||||||
|
|
|
@ -59,7 +59,7 @@ if __name__ == '__main__':
|
||||||
max_captcha_digits = cf.max_captcha_digits
|
max_captcha_digits = cf.max_captcha_digits
|
||||||
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
||||||
# create dataset
|
# create dataset
|
||||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size)
|
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
|
||||||
step_size = dataset.get_dataset_size()
|
step_size = dataset.get_dataset_size()
|
||||||
# define lr
|
# define lr
|
||||||
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
|
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
|
||||||
|
|
|
@ -290,7 +290,7 @@ def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix,
|
||||||
writer.commit()
|
writer.commit()
|
||||||
|
|
||||||
|
|
||||||
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
|
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=1, device_num=1, rank=0,
|
||||||
is_training=True, num_parallel_workers=8):
|
is_training=True, num_parallel_workers=8):
|
||||||
"""Creatr YOLOv3 dataset with MindDataset."""
|
"""Creatr YOLOv3 dataset with MindDataset."""
|
||||||
ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
|
ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
|
||||||
|
|
Loading…
Reference in New Issue