!19948 fix tinydarknet

Merge pull request !19948 from wanglin/fix_wl_master
This commit is contained in:
i-robot 2021-07-10 14:28:41 +00:00 committed by Gitee
commit 6cd41cbe74
1 changed files with 10 additions and 11 deletions

View File

@ -120,17 +120,6 @@ def modelarts_pre_process():
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
if config.dataset_name == "imagenet":
dataset = create_dataset_imagenet(config.train_data_dir, 1)
elif config.dataset_name == "cifar10":
dataset = create_dataset_cifar(dataset_path=config.train_data_dir,
do_train=True,
repeat_num=1,
batch_size=config.batch_size,
target=config.device_target)
else:
raise ValueError("Unsupported dataset.")
# set context
device_target = config.device_target
@ -149,6 +138,16 @@ def run_train():
init()
rank = get_rank_id()
if config.dataset_name == "imagenet":
dataset = create_dataset_imagenet(config.train_data_dir, 1)
elif config.dataset_name == "cifar10":
dataset = create_dataset_cifar(dataset_path=config.train_data_dir,
do_train=True,
repeat_num=1,
batch_size=config.batch_size,
target=config.device_target)
else:
raise ValueError("Unsupported dataset.")
batch_num = dataset.get_dataset_size()
net = TinyDarkNet(num_classes=config.num_classes)