forked from mindspore-Ecosystem/mindspore
support multy node training and remove code
This commit is contained in:
parent
1b69923472
commit
b0004a1791
|
@ -5,8 +5,7 @@ This is an example of training DeepLabV3 with PASCAL VOC 2012 dataset in MindSpo
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||||
- Download the VOC 2012 dataset for training.
|
- Download the VOC 2012 dataset for training.
|
||||||
- We need to run `./src/remove_gt_colormap.py` to remove the label colormap.
|
|
||||||
``` bash
|
``` bash
|
||||||
python remove_gt_colormap.py --original_gt_folder GT_FOLDER --output_dir OUTPUT_DIR
|
python remove_gt_colormap.py --original_gt_folder GT_FOLDER --output_dir OUTPUT_DIR
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ DATA_DIR=$2
|
||||||
export MINDSPORE_HCCL_CONFIG_PATH=$1
|
export MINDSPORE_HCCL_CONFIG_PATH=$1
|
||||||
export RANK_TABLE_FILE=$1
|
export RANK_TABLE_FILE=$1
|
||||||
export RANK_SIZE=8
|
export RANK_SIZE=8
|
||||||
|
export DEVICE_NUM=8
|
||||||
PATH_CHECKPOINT=""
|
PATH_CHECKPOINT=""
|
||||||
if [ $# == 3 ]
|
if [ $# == 3 ]
|
||||||
then
|
then
|
||||||
|
@ -37,11 +38,13 @@ avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
|
||||||
core_gap=`expr $avg_core_per_rank \- 1`
|
core_gap=`expr $avg_core_per_rank \- 1`
|
||||||
echo "avg_core_per_rank" $avg_core_per_rank
|
echo "avg_core_per_rank" $avg_core_per_rank
|
||||||
echo "core_gap" $core_gap
|
echo "core_gap" $core_gap
|
||||||
for((i=0;i<RANK_SIZE;i++))
|
export SERVER_ID=0
|
||||||
|
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||||
|
for((i=0;i<DEVICE_NUM;i++))
|
||||||
do
|
do
|
||||||
start=`expr $i \* $avg_core_per_rank`
|
start=`expr $i \* $avg_core_per_rank`
|
||||||
export DEVICE_ID=$i
|
export DEVICE_ID=$i
|
||||||
export RANK_ID=$i
|
export RANK_ID=$((rank_start + i))
|
||||||
export DEPLOY_MODE=0
|
export DEPLOY_MODE=0
|
||||||
export GE_USE_STATIC_MEMORY=1
|
export GE_USE_STATIC_MEMORY=1
|
||||||
end=`expr $start \+ $core_gap`
|
end=`expr $start \+ $core_gap`
|
||||||
|
|
|
@ -1,76 +0,0 @@
|
||||||
# Copyright 2020 The Huawei Authors All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
"""Removes the color map from segmentation annotations.
|
|
||||||
Removes the color map from the ground truth segmentation annotations and save
|
|
||||||
the results to output_dir.
|
|
||||||
"""
|
|
||||||
import glob
|
|
||||||
import argparse
|
|
||||||
import os.path
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_colormap(filename):
|
|
||||||
"""Removes the color map from the annotation.
|
|
||||||
Args:
|
|
||||||
filename: Ground truth annotation filename.
|
|
||||||
Returns:
|
|
||||||
Annotation without color map.
|
|
||||||
"""
|
|
||||||
return np.array(Image.open(filename))
|
|
||||||
|
|
||||||
|
|
||||||
def _save_annotation(annotation, filename):
|
|
||||||
"""Saves the annotation as png file.
|
|
||||||
Args:
|
|
||||||
annotation: Segmentation annotation.
|
|
||||||
filename: Output filename.
|
|
||||||
"""
|
|
||||||
pil_image = Image.fromarray(annotation.astype(dtype=np.uint8))
|
|
||||||
pil_image.save(filename, 'PNG')
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Demo of argparse")
|
|
||||||
parser.add_argument('--original_gt_folder', type=str, default='./VOCdevkit/VOC2012/SegmentationClass',
|
|
||||||
help='Original ground truth annotations.')
|
|
||||||
parser.add_argument('--segmentation_format', type=str, default='png',
|
|
||||||
help='Segmentation format.')
|
|
||||||
parser.add_argument('--output_dir', type=str, default='./VOCdevkit/VOC2012/SegmentationClassRaw',
|
|
||||||
help='folder to save modified ground truth annotations.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Create the output directory if not exists.
|
|
||||||
|
|
||||||
if not os.path.isdir(args.output_dir):
|
|
||||||
os.mkdir(args.output_dir)
|
|
||||||
|
|
||||||
annotations = glob.glob(os.path.join(args.original_gt_folder,
|
|
||||||
'*.' + args.segmentation_format))
|
|
||||||
|
|
||||||
for annotation in annotations:
|
|
||||||
raw_annotation = _remove_colormap(annotation)
|
|
||||||
filename = os.path.basename(annotation)[:-4]
|
|
||||||
_save_annotation(raw_annotation,
|
|
||||||
os.path.join(
|
|
||||||
args.output_dir,
|
|
||||||
filename + '.' + args.segmentation_format))
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
Loading…
Reference in New Issue