forked from mindspore-Ecosystem/mindspore
!19948 fix tinydarknet
Merge pull request !19948 from wanglin/fix_wl_master
This commit is contained in:
commit
6cd41cbe74
|
@ -120,17 +120,6 @@ def modelarts_pre_process():
|
|||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
if config.dataset_name == "imagenet":
|
||||
dataset = create_dataset_imagenet(config.train_data_dir, 1)
|
||||
elif config.dataset_name == "cifar10":
|
||||
dataset = create_dataset_cifar(dataset_path=config.train_data_dir,
|
||||
do_train=True,
|
||||
repeat_num=1,
|
||||
batch_size=config.batch_size,
|
||||
target=config.device_target)
|
||||
else:
|
||||
raise ValueError("Unsupported dataset.")
|
||||
|
||||
# set context
|
||||
device_target = config.device_target
|
||||
|
||||
|
@ -149,6 +138,16 @@ def run_train():
|
|||
init()
|
||||
rank = get_rank_id()
|
||||
|
||||
if config.dataset_name == "imagenet":
|
||||
dataset = create_dataset_imagenet(config.train_data_dir, 1)
|
||||
elif config.dataset_name == "cifar10":
|
||||
dataset = create_dataset_cifar(dataset_path=config.train_data_dir,
|
||||
do_train=True,
|
||||
repeat_num=1,
|
||||
batch_size=config.batch_size,
|
||||
target=config.device_target)
|
||||
else:
|
||||
raise ValueError("Unsupported dataset.")
|
||||
batch_num = dataset.get_dataset_size()
|
||||
|
||||
net = TinyDarkNet(num_classes=config.num_classes)
|
||||
|
|
Loading…
Reference in New Issue