From 8d4671b04e7eb03d4253b588a0cd9e49652aa6f1 Mon Sep 17 00:00:00 2001 From: l00486551 Date: Thu, 24 Jun 2021 15:19:23 +0800 Subject: [PATCH] fix data_path conflict --- model_zoo/official/cv/tinydarknet/cifar10_config.yaml | 4 ++-- model_zoo/official/cv/tinydarknet/scripts/run_train_cpu.sh | 2 +- model_zoo/official/cv/tinydarknet/train.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/model_zoo/official/cv/tinydarknet/cifar10_config.yaml b/model_zoo/official/cv/tinydarknet/cifar10_config.yaml index c41f26f53a2..fd343e01427 100644 --- a/model_zoo/official/cv/tinydarknet/cifar10_config.yaml +++ b/model_zoo/official/cv/tinydarknet/cifar10_config.yaml @@ -26,8 +26,8 @@ momentum: 0.9 weight_decay: 0.0001 image_height: 227 image_width: 227 -train_data_dir: './dataset/imagenet_original/train/' -val_data_dir: './dataset/imagenet_original/val/' +train_data_dir: './data/cifar10_train/' +val_data_dir: './data/cifar10_val/' keep_checkpoint_max: 1 checkpoint_path: './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt' onnx_filename: 'tinydarknet.onnx' diff --git a/model_zoo/official/cv/tinydarknet/scripts/run_train_cpu.sh b/model_zoo/official/cv/tinydarknet/scripts/run_train_cpu.sh index ede0b84a981..7a5276c46e6 100644 --- a/model_zoo/official/cv/tinydarknet/scripts/run_train_cpu.sh +++ b/model_zoo/official/cv/tinydarknet/scripts/run_train_cpu.sh @@ -53,5 +53,5 @@ cp ./*.yaml ./train_cpu echo "start training for device CPU" cd ./train_cpu || exit env > env.log -python train.py --device_target=CPU --data_path=$PATH1 --dataset_name=$2 --config_path=$CONFIG_FILE --lr_init=0.01> ./train.log 2>&1 & +python train.py --device_target=CPU --train_data_dir=$PATH1 --dataset_name=$2 --config_path=$CONFIG_FILE> ./train.log 2>&1 & cd .. diff --git a/model_zoo/official/cv/tinydarknet/train.py b/model_zoo/official/cv/tinydarknet/train.py index cdf96134ae9..6e38fb5663b 100644 --- a/model_zoo/official/cv/tinydarknet/train.py +++ b/model_zoo/official/cv/tinydarknet/train.py @@ -121,9 +121,9 @@ def modelarts_pre_process(): @moxing_wrapper(pre_process=modelarts_pre_process) def run_train(): if config.dataset_name == "imagenet": - dataset = create_dataset_imagenet(config.data_path, 1) + dataset = create_dataset_imagenet(config.train_data_dir, 1) elif config.dataset_name == "cifar10": - dataset = create_dataset_cifar(dataset_path=config.data_path, + dataset = create_dataset_cifar(dataset_path=config.train_data_dir, do_train=True, repeat_num=1, batch_size=config.batch_size,