commit eef3c58b5e (parent 52c59900a7)
Author: unknown
Date: 2020-05-29 11:47:13 +08:00

5 changed files with 69 additions and 6 deletions


@@ -0,0 +1,64 @@
# DeepLab-V3 Example

## Description

This is an example of training DeepLab-V3 on the PASCAL VOC 2012 dataset with MindSpore.

## Paper

Rethinking Atrous Convolution for Semantic Image Segmentation.
Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.

## Requirements

- Install MindSpore.
- Download the PASCAL VOC 2012 dataset for training.

For more information, please check the resources below:

- MindSpore tutorials
- MindSpore API

> Note: if you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file first.

## Running the Example

### Training

- Set options in config.py.
- Run run_standalone_train.sh for non-distributed training:

  ```bash
  sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_DIR
  ```

- Run run_distribute_train.sh for distributed training:

  ```bash
  sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH
  ```

### Evaluation

- Set options in evaluation_config.py; make sure 'data_file' and 'finetune_ckpt' are set to your own paths.
- Run run_eval.sh for evaluation:

  ```bash
  sh scripts/run_eval.sh DEVICE_ID DATA_DIR
  ```
## Options and Parameters

Options for the DeepLab-V3 model and for training are set in config.py (a sketch of these defaults follows the list).

Options:

```
config.py:
    learning_rate             Learning rate, default is 0.0014.
    weight_decay              Weight decay, default is 5e-5.
    momentum                  Momentum, default is 0.97.
    crop_size                 Image crop size [height, width] during training, default is 513.
    eval_scales               Scales at which images are resized for evaluation,
                              default is [0.5, 0.75, 1.0, 1.25, 1.5, 1.75].
    output_stride             Ratio of input to output spatial resolution, default is 16.
    ignore_label              Label value to be ignored, default is 255.
    seg_num_classes           Number of semantic classes, including the background class if present;
                              20 foreground classes + 1 background class for PASCAL VOC 2012,
                              default is 21.
    fine_tune_batch_norm      Whether to fine-tune the batch-norm parameters, default is False.
    atrous_rates              Atrous rates for atrous spatial pyramid pooling, default is None.
    decoder_output_stride     Ratio of input to output spatial resolution when a decoder is
                              employed to refine segmentation results, default is None.
    image_pyramid             Input scales for multi-scale feature extraction, default is None.
```
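For reference, here is a minimal sketch of what config.py could look like with the defaults listed above. The EasyDict-style wrapper is an assumption (a common pattern in MindSpore examples); only the values themselves come from this list.

```python
# Hypothetical sketch of config.py; the values are the documented defaults,
# the EasyDict structure is assumed, not taken from this commit.
from easydict import EasyDict as ed

config = ed({
    "learning_rate": 0.0014,
    "weight_decay": 5e-5,
    "momentum": 0.97,
    "crop_size": 513,
    "eval_scales": [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
    "output_stride": 16,
    "ignore_label": 255,
    "seg_num_classes": 21,
    "fine_tune_batch_norm": False,
    "atrous_rates": None,
    "decoder_output_stride": None,
    "image_pyramid": None,
})
```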
Parameters (a sketch of how the checkpoint-related ones are consumed follows the list):

```
Parameters for dataset and network:
    distribute                Whether to run distributed training, default is false.
    epoch_size                Epoch size, default is 6.
    batch_size                Batch size of the input dataset, default is 2.
    data_url                  Train/evaluation data URL, required.
    checkpoint_url            Checkpoint path, default is None.
    enable_save_ckpt          Whether to save checkpoints, default is true.
    save_checkpoint_steps     Interval in steps between saved checkpoints, default is 1000.
    save_checkpoint_num       Maximum number of checkpoints to keep, default is 1.
```
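To show how the three checkpoint parameters work together, here is a sketch of the callback wiring. The CheckpointConfig and ModelCheckpoint calls mirror the ones visible in the training-script diff below; the enable_save_ckpt guard and the stand-in args_opt are assumptions.

```python
# Sketch: mapping the checkpoint parameters onto MindSpore's callbacks.
# SimpleNamespace stands in for the argparse result in the training script.
from types import SimpleNamespace
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

args_opt = SimpleNamespace(enable_save_ckpt="true",
                           save_checkpoint_steps=1000,
                           save_checkpoint_num=1)

callback = []
if args_opt.enable_save_ckpt == "true":  # the flag is a string, per the argparse definition
    config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                 keep_checkpoint_max=args_opt.save_checkpoint_num)
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
    callback.append(ckpoint_cb)
```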


@@ -26,7 +26,7 @@ parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
-parser.add_argument('--data_url', required=True, default=None, help='Train data url')
+parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url')
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
args_opt = parser.parse_args()
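One side note on the changed line: because required=True is set, argparse never falls back to default=None, so the default is inert. A minimal standalone reproduction (the path is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
# With required=True, a missing --data_url aborts with an error,
# so default=None is never consulted.
parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url')

args_opt = parser.parse_args(['--data_url', '/path/to/VOC2012'])
print(args_opt.data_url)  # -> /path/to/VOC2012
```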


@@ -15,8 +15,8 @@
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_eval.sh DEVICE_ID EPOCH_SIZE DATA_DIR"
echo "for example: bash run_eval.sh 0 /path/zh-wiki/ "
echo "bash run_eval.sh DEVICE_ID DATA_DIR"
echo "for example: bash run_eval.sh /path/zh-wiki/ "
echo "=============================================================================================================="
DEVICE_ID=$1
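For context, the DEVICE_ID captured here usually reaches MindSpore through the evaluation script's --device_id argument. The following is a sketch of the common wiring, an assumption about this script's surroundings rather than code from this commit:

```python
# Assumed pattern: the positional DEVICE_ID ends up in MindSpore's context.
from mindspore import context

device_id = 0  # would come from args_opt.device_id in the evaluation script
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
```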


@ -27,13 +27,12 @@ from src.config import config
parser = argparse.ArgumentParser(description="Deeplabv3 training")
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
-parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.')
+parser.add_argument('--epoch_size', type=int, default=6, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size.')
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
-parser.add_argument('--max_checkpoint_num', type=int, default=5, help='Max checkpoint number.')
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
args_opt = parser.parse_args()
@@ -80,7 +79,7 @@ if __name__ == "__main__":
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
callback.append(ckpoint_cb)
-net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
+net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
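The second positional argument here is the NCHW input shape. Below is a small sketch of the shape arithmetic implied by output_stride, using the defaults documented above; the helper computation is illustrative, not code from this commit:

```python
import math

# NCHW shape passed to deeplabv3_resnet50, with the documented defaults.
batch_size, channels, crop_size = 2, 3, 513
feature_shape = [batch_size, channels, crop_size, crop_size]

# output_stride is the ratio of input to output spatial resolution (default 16),
# so a 513x513 crop leaves roughly ceil(513 / 16) = 33 positions per side at
# the encoder output before any decoder refinement.
output_stride = 16
print(feature_shape, math.ceil(crop_size / output_stride))  # [2, 3, 513, 513] 33
```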