!18444 merge psenet&retinanet&cnnctc&cnn_direction_model
Merge pull request !18444 from Maige/master
This commit is contained in:
commit
ab6adbef02
|
@ -93,18 +93,23 @@ sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH]
|
|||
│ ├──run_standalone_eval_ascend.sh // evaluate in ascend
|
||||
│ ├──run_standalone_train_ascend.sh // train standalone in ascend
|
||||
├── src
|
||||
│ ├──dataset.py // creating dataset
|
||||
│ ├──dataset.py // creating dataset
|
||||
│ ├──cnn_direction_model.py // cnn_direction_model architecture
|
||||
│ ├──config.py // parameter configuration
|
||||
│ ├──create_mindrecord.py // convert raw data to mindrecords
|
||||
├── model_utils
|
||||
├──config.py // Parameter config
|
||||
├──moxing_adapter.py // modelarts device configuration
|
||||
├──device_adapter.py // Device Config
|
||||
├──local_adapter.py // local device config
|
||||
├── train.py // training script
|
||||
├── eval.py // evaluation script
|
||||
├── default_config.yaml // config file
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
```python
|
||||
Major parameters in config.py as follows:
|
||||
```default_config.yaml
|
||||
Major parameters in default_config.yaml as follows:
|
||||
|
||||
--data_root_train: The path to the raw training data images for conversion to mindrecord script.
|
||||
--data_root_test: The path to the raw test data images for conversion to mindrecord script.
|
||||
|
@ -124,7 +129,7 @@ Major parameters in config.py as follows:
|
|||
- running on Ascend
|
||||
|
||||
```python
|
||||
sh run_standalone_train_ascend.sh path-to-train-mindrecords pre-trained-chkpt(optional)
|
||||
sh scripts/run_standalone_train_ascend.sh device_id path-to-train-mindrecords pre-trained-chkpt(optional)
|
||||
```
|
||||
|
||||
The model checkpoint will be saved script/train.
|
||||
|
@ -138,11 +143,85 @@ Before running the command below, please check the checkpoint path used for eval
|
|||
- running on Ascend
|
||||
|
||||
```python
|
||||
sh run_standalone_eval_ascend.sh path-to-test-mindrecords trained-chkpt-path
|
||||
sh scripts/run_standalone_eval_ascend.sh device_id path-to-test-mindrecords trained-chkpt-path
|
||||
```
|
||||
|
||||
Results of evaluation will be printed after evaluation process is completed.
|
||||
|
||||
### [Distributed Training](#contains)
|
||||
|
||||
#### Running on Ascend
|
||||
|
||||
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
|
||||
|
||||
Please follow the instructions in the link below:
|
||||
|
||||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
|
||||
|
||||
Run `scripts/run_distribute_train_ascend.sh` to train the model distributed. The usage of the script is:
|
||||
|
||||
```text
|
||||
sh scripts/run_distribute_train_ascend.sh [rank_table] [train_dataset_path] [PRETRAINED_CKPT_PATH(optional)]
|
||||
```
|
||||
|
||||
For example, you can run the shell command below to launch the training procedure.
|
||||
|
||||
```shell
|
||||
sh scripts/run_distribute_train_ascend.sh /home/rank_table.json /home/fsns/train/
|
||||
```
|
||||
|
||||
- running on ModelArts
|
||||
- If you want to train the model on modelarts, you can refer to the [official guidance document] of modelarts (https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```python
|
||||
# Example of using distributed training dpn on modelarts :
|
||||
# Data set storage method
|
||||
|
||||
# ├── FSNS # dir
|
||||
# ├── train # train dir
|
||||
# ├── train.zip # mindrecord train dataset zip
|
||||
# ├── pre_trained # predtrained dir if exists
|
||||
# ├── eval # eval dir
|
||||
# ├── test.zip # mindrecord eval dataset zip
|
||||
# ├── checkpoint # ckpt files dir
|
||||
|
||||
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters) 。
|
||||
# a. set "enable_modelarts=True" 。
|
||||
# set "run_distribute=True"
|
||||
# set "save_checkpoint_path=/cache/train/outputs/"
|
||||
# set "train_dataset_path=/cache/data/train/"
|
||||
# set "pre_trained=/cache/data/pre_trained/pred file name" Without pre-training weights pre_trained=""
|
||||
#
|
||||
# b. add "enable_modelarts=True" Parameters are on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
|
||||
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) Set the code path on the modelarts interface "/path/cnn_direction_model"。
|
||||
# (4) Set the model's startup file on the modelarts interface "train.py" 。
|
||||
# (5) Set the data path of the model on the modelarts interface ".../FSNS/train"(choices FSNS/train Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
# (6) start trainning the model。
|
||||
|
||||
# Example of using model inference on modelarts
|
||||
# (1) Place the trained model to the corresponding position of the bucket。
|
||||
# (2) chocie a or b。
|
||||
# a. set "enable_modelarts=True" 。
|
||||
# set "eval_dataset_path=/cache/data/test/"
|
||||
# set "checkpoint_path=/cache/data/checkpoint/checkpoint file name"
|
||||
|
||||
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
|
||||
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (4) Set the code path on the modelarts interface "/path/cnn_direction_model"。
|
||||
# (5) Set the model's startup file on the modelarts interface "eval.py" 。
|
||||
# (6) Set the data path of the model on the modelarts interface ".../FSNS/eval"(choices FSNS/eval Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
# (7) Start model inference。
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
need_modelarts_dataset_unzip: True
|
||||
modelarts_dataset_unzip_name: "FSNS"
|
||||
|
||||
# ======================================================================================
|
||||
# common options
|
||||
run_distribute: False
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
|
||||
# create train dataset options
|
||||
train_annotation_file: ""
|
||||
data_root_train: ""
|
||||
mindrecord_dir: ""
|
||||
|
||||
# training options
|
||||
dataset_name: "fsns"
|
||||
batch_size: 8
|
||||
epoch_size: 1
|
||||
pretrain_epoch_size: 0
|
||||
save_checkpoint: True
|
||||
save_checkpoint_steps: 2500
|
||||
save_checkpoint_epochs: 1
|
||||
keep_checkpoint_max: 20
|
||||
save_checkpoint_path: "./"
|
||||
warmup_epochs: 5
|
||||
lr_decay_mode: "poly"
|
||||
lr: 5e-4
|
||||
work_nums: 4
|
||||
im_size_w: 512
|
||||
im_size_h: 64
|
||||
pos_samples_size: 100
|
||||
augment_severity: 0.1
|
||||
augment_prob: 0.3
|
||||
train_dataset_path: ""
|
||||
pre_trained: ""
|
||||
is_save_on_master: 1
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
|
||||
# create eval dataset options
|
||||
test_annotation_file: ""
|
||||
data_root_test: ""
|
||||
|
||||
# eval options
|
||||
eval_dataset_path: ""
|
||||
checkpoint_path: ""
|
||||
|
||||
# export options
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
|
||||
|
|
@ -13,40 +13,88 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test direction model."""
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from src.cnn_direction_model import CNNDirectionModel
|
||||
from src.config import config1 as config
|
||||
from src.dataset import create_dataset_eval
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
args_opt = parser.parse_args()
|
||||
from src.cnn_direction_model import CNNDirectionModel
|
||||
from src.dataset import create_dataset_eval
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
from src.model_utils.config import config
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def model_eval():
|
||||
# init context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
|
||||
device_id = get_device_id()
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
# create dataset
|
||||
dataset_name = config.dataset_name
|
||||
dataset_lr, dataset_rl = create_dataset_eval(args_opt.dataset_path + "/" + dataset_name +
|
||||
dataset_lr, dataset_rl = create_dataset_eval(config.eval_dataset_path + "/" + dataset_name +
|
||||
".mindrecord0", config=config, dataset_name=dataset_name)
|
||||
step_size = dataset_lr.get_dataset_size()
|
||||
|
||||
|
@ -56,7 +104,7 @@ if __name__ == '__main__':
|
|||
net = CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512])
|
||||
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
param_dict = load_checkpoint(config.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
|
@ -69,5 +117,9 @@ if __name__ == '__main__':
|
|||
# eval model
|
||||
res_lr = model.eval(dataset_lr, dataset_sink_mode=False)
|
||||
res_rl = model.eval(dataset_rl, dataset_sink_mode=False)
|
||||
print("result on upright images:", res_lr, "ckpt=", args_opt.checkpoint_path)
|
||||
print("result on 180 degrees rotated images:", res_rl, "ckpt=", args_opt.checkpoint_path)
|
||||
print("result on upright images:", res_lr, "ckpt=", config.checkpoint_path)
|
||||
print("result on 180 degrees rotated images:", res_rl, "ckpt=", config.checkpoint_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_eval()
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# ============================================================================
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
|
||||
echo "Usage: sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -67,21 +67,22 @@ do
|
|||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp ./*.py ./train_parallel$i
|
||||
cp -r ./scripts ./train_parallel$i
|
||||
cp ./*yaml ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 > train.log 2>&1 &
|
||||
python train.py --run_distribute=True --train_dataset_path=$PATH2 > train.log 2>&1 &
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 > train.log 2>&1 &
|
||||
python train.py --run_distribute=True --train_dataset_path=$PATH2 --pre_trained=$PATH3 > train.log 2>&1 &
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
|
|
@ -14,15 +14,15 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH]"
|
||||
echo "Usage: sh scripts/run_standalone_train.sh [DEVICE_ID] [DATASET_PATH] [PRETRAINED_CKPT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=4
|
||||
export DEVICE_ID=$1
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
|
@ -35,8 +35,8 @@ get_real_path(){
|
|||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
PATH1=$(get_real_path $2)
|
||||
PATH2=$(get_real_path $3)
|
||||
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
|
@ -50,13 +50,14 @@ then
|
|||
fi
|
||||
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cp ./*.py ./eval
|
||||
cp -r ./scripts ./eval
|
||||
cp -r ./src ./eval
|
||||
cp ./*yaml ./eval
|
||||
cd ./eval || exit
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
env > env.log
|
||||
|
||||
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 > eval.log 2>&1 &
|
||||
python eval.py --eval_dataset_path=$PATH1 --checkpoint_path=$PATH2 > eval.log 2>&1 &
|
||||
|
||||
cd ..
|
||||
|
|
|
@ -14,15 +14,15 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ] && [ $# != 2 ]
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
|
||||
echo "Usage: sh scripts/run_standalone_train_ascend.sh [DEVICE_ID] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=3
|
||||
export DEVICE_ID=$1
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
|
@ -35,14 +35,14 @@ get_real_path(){
|
|||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH1=$(get_real_path $2)
|
||||
|
||||
if [ $# == 2 ]
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PATH2=$(get_real_path $2)
|
||||
PATH2=$(get_real_path $3)
|
||||
fi
|
||||
|
||||
if [ $# == 2 ] && [ ! -f $PATH2 ]
|
||||
if [ $# == 3 ] && [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
|
@ -53,20 +53,20 @@ then
|
|||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cp ./*.py ./train
|
||||
cp -r ./scripts ./train
|
||||
cp ./*yaml ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ $# == 1 ]
|
||||
then
|
||||
python train.py --dataset_path=$PATH1 &> log &
|
||||
fi
|
||||
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 > train.log 2>&1 &
|
||||
python train.py --train_dataset_path=$PATH1 &> log &
|
||||
fi
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
python train.py --train_dataset_path=$PATH1 --pre_trained=$PATH2 > train.log 2>&1 &
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
|
|
@ -1,50 +0,0 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
|
||||
config1 = ed({
|
||||
# dataset metadata
|
||||
"dataset_name": "fsns",
|
||||
# annotation files paths
|
||||
"train_annotation_file": "path-to-file",
|
||||
"test_annotation_file": "path-to-file",
|
||||
# dataset root paths
|
||||
"data_root_train": "path-to-dir",
|
||||
"data_root_test": "path-to-dir",
|
||||
# mindrecord target locations
|
||||
"mindrecord_dir": "path-to-dir",
|
||||
# training and testing params
|
||||
"batch_size": 8,
|
||||
"epoch_size": 5,
|
||||
"pretrain_epoch_size": 0,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_steps": 2500,
|
||||
"save_checkpoint_epochs": 10,
|
||||
"keep_checkpoint_max": 20,
|
||||
"save_checkpoint_path": "./",
|
||||
"warmup_epochs": 5,
|
||||
"lr_decay_mode": "poly",
|
||||
"lr": 1e-4,
|
||||
"work_nums": 4,
|
||||
"im_size_w": 512,
|
||||
"im_size_h": 64,
|
||||
"pos_samples_size": 100,
|
||||
"augment_severity": 0.1,
|
||||
"augment_prob": 0.3
|
||||
})
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
|
||||
_config_path = '../../default_config.yaml'
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml
|
||||
|
||||
Args:
|
||||
parser: Parent parser
|
||||
cfg: Base configuration
|
||||
helper: Helper description
|
||||
cfg_path: Path to the default yaml config
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml')
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError('Failed to parse yaml')
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments
|
||||
|
||||
Args:
|
||||
args: command line arguments
|
||||
cfg: Base configuration
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='default name', add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path),
|
||||
help='Config file path')
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return 'Local Job'
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
|
||||
_global_syn_count = 0
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local
|
||||
Uploca data from local directory to remote obs in contrast
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_syn_count
|
||||
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
|
||||
_global_syn_count += 1
|
||||
|
||||
# Each server contains 8 devices as most
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print('from path: ', from_path)
|
||||
print('to path: ', to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print('===finished data synchronization===')
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print('===save flag===')
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
print('Finish sync data from {} to {}'.format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print('Dataset downloaded: ', os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
if not os.path.exists(config.load_path):
|
||||
# os.makedirs(config.load_path)
|
||||
print('=' * 20 + 'makedirs')
|
||||
if os.path.isdir(config.load_path):
|
||||
print('=' * 20 + 'makedirs success')
|
||||
else:
|
||||
print('=' * 20 + 'makedirs fail')
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print('Preload downloaded: ', os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print('Workspace downloaded: ', os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print('Start to copy output directory')
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -13,11 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train CNN direction model."""
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
from ast import literal_eval as liter
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
@ -29,57 +30,103 @@ from mindspore.nn.optim.adam import Adam
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.cnn_direction_model import CNNDirectionModel
|
||||
from src.config import config1 as config
|
||||
from src.dataset import create_dataset_train
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
||||
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
random.seed(11)
|
||||
np.random.seed(11)
|
||||
de.config.set_seed(11)
|
||||
ms.common.set_seed(11)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
target = args_opt.device_target
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train():
|
||||
config.lr = liter(config.lr)
|
||||
target = config.device_target
|
||||
ckpt_save_dir = config.save_checkpoint_path
|
||||
|
||||
# init context
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
rank_id = int(os.getenv('RANK_ID', '0'))
|
||||
rank_size = int(os.getenv('RANK_SIZE', '1'))
|
||||
device_id = get_device_id()
|
||||
rank_id = get_rank_id()
|
||||
rank_size = get_device_num()
|
||||
run_distribute = rank_size > 1
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
device_target=target,
|
||||
device_id=device_id, save_graphs=False)
|
||||
|
||||
print("train args: ", args_opt, "\ncfg: ", config,
|
||||
print("train args: ", config, "\ncfg: ", config,
|
||||
"\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
|
||||
|
||||
if run_distribute:
|
||||
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||
init()
|
||||
|
||||
args_opt.rank_save_ckpt_flag = 0
|
||||
if args_opt.is_save_on_master:
|
||||
config.rank_save_ckpt_flag = 0
|
||||
if config.is_save_on_master:
|
||||
if rank_id == 0:
|
||||
args_opt.rank_save_ckpt_flag = 1
|
||||
config.rank_save_ckpt_flag = 1
|
||||
else:
|
||||
args_opt.rank_save_ckpt_flag = 1
|
||||
config.rank_save_ckpt_flag = 1
|
||||
|
||||
# create dataset
|
||||
dataset_name = config.dataset_name
|
||||
dataset = create_dataset_train(args_opt.dataset_path + "/" + dataset_name +
|
||||
dataset = create_dataset_train(config.train_dataset_path + "/" + dataset_name +
|
||||
".mindrecord0", config=config, dataset_name=dataset_name)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
|
@ -87,8 +134,8 @@ if __name__ == '__main__':
|
|||
net = CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512])
|
||||
|
||||
# init weight
|
||||
if args_opt.pre_trained:
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
if config.pre_trained:
|
||||
param_dict = load_checkpoint(config.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
lr = config.lr
|
||||
|
@ -107,7 +154,7 @@ if __name__ == '__main__':
|
|||
loss_cb = LossMonitor()
|
||||
cb = [time_cb, loss_cb]
|
||||
if config.save_checkpoint:
|
||||
if args_opt.rank_save_ckpt_flag == 1:
|
||||
if config.rank_save_ckpt_flag == 1:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
|
||||
|
@ -115,3 +162,7 @@ if __name__ == '__main__':
|
|||
|
||||
# train model
|
||||
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
# ======================================================================================
|
||||
# common options
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
CHARACTER: "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
# NUM_CLASS = len(CHARACTER) + 1
|
||||
NUM_CLASS: 37
|
||||
|
||||
HIDDEN_SIZE: 512
|
||||
FINAL_FEATURE_WIDTH: 26
|
||||
|
||||
# dataset config
|
||||
IMG_H: 32
|
||||
IMG_W: 100
|
||||
TRAIN_DATASET_PATH: "CNNCTC_Data/ST_MJ/"
|
||||
TRAIN_DATASET_INDEX_PATH: "CNNCTC_Data/st_mj_fixed_length_index_list.pkl"
|
||||
TRAIN_BATCH_SIZE: 192
|
||||
TRAIN_EPOCHS: 3
|
||||
|
||||
# training config
|
||||
run_distribute: False
|
||||
PRED_TRAINED: ""
|
||||
SAVE_PATH: "./"
|
||||
LR: 1e-4
|
||||
LR_PARA: 5e-4
|
||||
MOMENTUM: 0.8
|
||||
LOSS_SCALE: 8096
|
||||
SAVE_CKPT_PER_N_STEP: 2000
|
||||
KEEP_CKPT_MAX_NUM: 5
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
TEST_DATASET_PATH: "CNNCTC_Data/IIIT5k_3000"
|
||||
TEST_BATCH_SIZE: 256
|
||||
CHECKPOINT_PATH: ""
|
||||
|
||||
# export options
|
||||
device_id: 0
|
||||
file_name: "cnnctc"
|
||||
file_format: "MINDIR"
|
||||
ckpt_file: ""
|
||||
|
||||
# 310 infer
|
||||
result_path: ""
|
||||
label_path: ""
|
||||
preprocess_output: ""
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
file_name: "CNN&CTC output air name"
|
||||
file_format: "choices [AIR, MINDIR]"
|
||||
ckpt_file: "CNN&CTC ckpt file"
|
||||
|
|
@ -14,34 +14,36 @@
|
|||
# ============================================================================
|
||||
"""cnnctc eval"""
|
||||
|
||||
import argparse
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.dataset import GeneratorDataset
|
||||
|
||||
from src.util import CTCLabelConverter, AverageMeter
|
||||
from src.config import Config_CNNCTC
|
||||
from src.dataset import IIIT_Generator_batch
|
||||
from src.cnn_ctc import CNNCTC_Model
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||
save_graphs_path=".", enable_auto_mixed_precision=False)
|
||||
|
||||
|
||||
def test_dataset_creator():
|
||||
ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
|
||||
return ds
|
||||
|
||||
|
||||
def test(config):
|
||||
@moxing_wrapper(pre_process=None)
|
||||
def test():
|
||||
ds = test_dataset_creator()
|
||||
|
||||
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||
|
||||
ckpt_path = config.CKPT_PATH
|
||||
ckpt_path = config.CHECKPOINT_PATH
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('parameters loaded! from: ', ckpt_path)
|
||||
|
@ -98,12 +100,4 @@ def test(config):
|
|||
print('accuracy: ', correct_count / count)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="FasterRcnn training")
|
||||
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
cfg = Config_CNNCTC()
|
||||
if args_opt.ckpt_path != "":
|
||||
cfg.CKPT_PATH = args_opt.ckpt_path
|
||||
test(cfg)
|
||||
test()
|
||||
|
|
|
@ -12,42 +12,38 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, onnx, mindir models"""
|
||||
import argparse
|
||||
"""export checkpoint file into air, onnx, mindir models
|
||||
suggest run as python export.py --filename cnnctc --file_format MINDIR --ckpt_file [ckpt file path]
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context, load_checkpoint, export
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src.config import Config_CNNCTC
|
||||
from src.cnn_ctc import CNNCTC_Model
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
parser = argparse.ArgumentParser(description="CNNCTC_export")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--file_name", type=str, default="cnn_ctc", help="CNN&CTC output air name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
parser.add_argument("--ckpt_file", type=str, default="./ckpts/cnn_ctc.ckpt", help="CNN&CTC ckpt file.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=config.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = Config_CNNCTC()
|
||||
ckpt_path = cfg.CKPT_PATH
|
||||
|
||||
if args_opt.ckpt_file != "":
|
||||
ckpt_path = args_opt.ckpt_file
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
net = CNNCTC_Model(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
|
||||
|
||||
load_checkpoint(ckpt_path, net=net)
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def model_export():
|
||||
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||
|
||||
bs = cfg.TEST_BATCH_SIZE
|
||||
load_checkpoint(config.ckpt_file, net=net)
|
||||
|
||||
input_data = Tensor(np.zeros([bs, 3, cfg.IMG_H, cfg.IMG_W]), mstype.float32)
|
||||
bs = config.TEST_BATCH_SIZE
|
||||
|
||||
export(net, input_data, file_name=args_opt.file_name, file_format=args_opt.file_format)
|
||||
input_data = Tensor(np.zeros([bs, 3, config.IMG_H, config.IMG_W]), mstype.float32)
|
||||
|
||||
export(net, input_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_export()
|
||||
|
|
|
@ -14,6 +14,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ] && [ $# != 2 ]
|
||||
then
|
||||
echo "run as scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE PRED_TRAINED(options)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
|
@ -41,15 +47,16 @@ do
|
|||
cp ./*.py ./train_parallel_$i
|
||||
cp ./scripts/*.sh ./train_parallel_$i
|
||||
cp -r ./src ./train_parallel_$i
|
||||
cp ./*yaml ./train_parallel_$i
|
||||
cd ./train_parallel_$i || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
if [ -f $PATH2 ]
|
||||
then
|
||||
python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
|
||||
python train.py --PRED_TRAINED=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
|
||||
else
|
||||
python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 &
|
||||
python train.py --run_distribute=True >log_$i.log 2>&1 &
|
||||
fi
|
||||
cd .. || exit
|
||||
done
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 1 ]
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]"
|
||||
echo "Usage: sh scripts/run_eval_ascend.sh [DEVICE_ID] [TRAINED_CKPT]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -28,7 +28,7 @@ get_real_path(){
|
|||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH1=$(get_real_path $2)
|
||||
echo $PATH1
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
|
@ -37,7 +37,7 @@ exit 1
|
|||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_ID=0
|
||||
export DEVICE_ID=$1
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
|
@ -47,8 +47,9 @@ mkdir ./eval
|
|||
cp ./*.py ./eval
|
||||
cp ./scripts/*.sh ./eval
|
||||
cp -r ./src ./eval
|
||||
cp ./*yaml ./eval
|
||||
cd ./eval || exit
|
||||
echo "start inferring for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log &
|
||||
python eval.py --CHECKPOINT_PATH=$PATH1 &> log &
|
||||
cd .. || exit
|
||||
|
|
|
@ -14,6 +14,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ] && [ $# != 2 ]
|
||||
then
|
||||
echo "run as sh scripts/run_standalone_train_ascend.sh DEVICE_ID PRE_TRAINED(options)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
|
@ -21,7 +27,9 @@ get_real_path(){
|
|||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH1=$(get_real_path $2)
|
||||
|
||||
export DEVICE_ID=$1
|
||||
|
||||
ulimit -u unlimited
|
||||
|
||||
|
@ -33,13 +41,14 @@ mkdir ./train
|
|||
cp ./*.py ./train
|
||||
cp ./scripts/*.sh ./train
|
||||
cp -r ./src ./train
|
||||
cp ./*yaml ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ -f $PATH1 ]
|
||||
then
|
||||
python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log &
|
||||
python train.py --PRED_TRAINED=$PATH1 --run_distribute=False &> log &
|
||||
else
|
||||
python train.py --device_id=$DEVICE_ID --run_distribute=False &> log &
|
||||
python train.py --run_distribute=False &> log &
|
||||
fi
|
||||
cd .. || exit
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""network config setting, will be used in train.py and eval.py"""
|
||||
|
||||
class Config_CNNCTC():
|
||||
# model config
|
||||
CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
|
||||
NUM_CLASS = len(CHARACTER) + 1
|
||||
HIDDEN_SIZE = 512
|
||||
FINAL_FEATURE_WIDTH = 26
|
||||
|
||||
# dataset config
|
||||
IMG_H = 32
|
||||
IMG_W = 100
|
||||
TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/'
|
||||
TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl'
|
||||
TRAIN_BATCH_SIZE = 192
|
||||
TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000'
|
||||
TEST_BATCH_SIZE = 256
|
||||
TRAIN_EPOCHS = 3
|
||||
|
||||
# training config
|
||||
CKPT_PATH = ''
|
||||
SAVE_PATH = './'
|
||||
LR = 1e-4
|
||||
LR_PARA = 5e-4
|
||||
MOMENTUM = 0.8
|
||||
LOSS_SCALE = 8096
|
||||
SAVE_CKPT_PER_N_STEP = 2000
|
||||
KEEP_CKPT_MAX_NUM = 5
|
|
@ -21,13 +21,10 @@ import six
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
import lmdb
|
||||
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
|
||||
from src.model_utils.config import config
|
||||
from .util import CTCLabelConverter
|
||||
from .config import Config_CNNCTC
|
||||
|
||||
config = Config_CNNCTC()
|
||||
|
||||
class NormalizePAD():
|
||||
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
|
||||
_config_path = '../../default_config.yaml'
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml
|
||||
|
||||
Args:
|
||||
parser: Parent parser
|
||||
cfg: Base configuration
|
||||
helper: Helper description
|
||||
cfg_path: Path to the default yaml config
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml')
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError('Failed to parse yaml')
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments
|
||||
|
||||
Args:
|
||||
args: command line arguments
|
||||
cfg: Base configuration
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='default name', add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path),
|
||||
help='Config file path')
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return 'Local Job'
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
|
||||
_global_syn_count = 0
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local
|
||||
Uploca data from local directory to remote obs in contrast
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_syn_count
|
||||
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
|
||||
_global_syn_count += 1
|
||||
|
||||
# Each server contains 8 devices as most
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print('from path: ', from_path)
|
||||
print('to path: ', to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print('===finished data synchronization===')
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print('===save flag===')
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
print('Finish sync data from {} to {}'.format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print('Dataset downloaded: ', os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
if not os.path.exists(config.load_path):
|
||||
# os.makedirs(config.load_path)
|
||||
print('=' * 20 + 'makedirs')
|
||||
if os.path.isdir(config.load_path):
|
||||
print('=' * 20 + 'makedirs success')
|
||||
else:
|
||||
print('=' * 20 + 'makedirs fail')
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print('Preload downloaded: ', os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print('Workspace downloaded: ', os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print('Start to copy output directory')
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -14,9 +14,8 @@
|
|||
# ============================================================================
|
||||
"""cnnctc train"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
import ast
|
||||
import mindspore
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
@ -25,14 +24,17 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import Config_CNNCTC
|
||||
from src.callback import LossCallBack
|
||||
from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para
|
||||
from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||
save_graphs_path=".", enable_auto_mixed_precision=False)
|
||||
|
||||
|
@ -50,18 +52,27 @@ def dataset_creator(run_distribute):
|
|||
return ds
|
||||
|
||||
|
||||
def train(args_opt, config):
|
||||
if args_opt.run_distribute:
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train():
|
||||
device_id = get_device_id()
|
||||
if config.run_distribute:
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel")
|
||||
|
||||
ds = dataset_creator(args_opt.run_distribute)
|
||||
config.LR = ast.literal_eval(config.LR)
|
||||
config.LR_PARA = ast.literal_eval(config.LR_PARA)
|
||||
|
||||
ds = dataset_creator(config.run_distribute)
|
||||
|
||||
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||
net.set_train(True)
|
||||
|
||||
if config.CKPT_PATH != '':
|
||||
param_dict = load_checkpoint(config.CKPT_PATH)
|
||||
if config.PRED_TRAINED:
|
||||
param_dict = load_checkpoint(config.PRED_TRAINED)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('parameters loaded!')
|
||||
else:
|
||||
|
@ -80,8 +91,8 @@ def train(args_opt, config):
|
|||
keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH)
|
||||
|
||||
if args_opt.run_distribute:
|
||||
if args_opt.device_id == 0:
|
||||
if config.run_distribute:
|
||||
if device_id == 0:
|
||||
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False)
|
||||
else:
|
||||
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False)
|
||||
|
@ -90,14 +101,4 @@ def train(args_opt, config):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='CNNCTC arg')
|
||||
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--ckpt_path", type=str, default="", help="Pretrain file path.")
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
|
||||
help="Run distribute, default is false.")
|
||||
args_cfg = parser.parse_args()
|
||||
|
||||
cfg = Config_CNNCTC()
|
||||
if args_cfg.ckpt_path != "":
|
||||
cfg.CKPT_PATH = args_cfg.ckpt_path
|
||||
train(args_cfg, cfg)
|
||||
train()
|
||||
|
|
|
@ -66,7 +66,7 @@ After installing MindSpore via the official website, you can start training and
|
|||
|
||||
```python
|
||||
# run distributed training example
|
||||
sh scripts/run_distribute_train.sh rank_table_file pretrained_model.ckpt
|
||||
sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
|
||||
|
||||
#download opencv library
|
||||
download pyblind11, opencv3.4
|
||||
|
@ -79,7 +79,7 @@ setup opencv3.4(compile source code install the library)
|
|||
cd ./src/ETSNET/pse/;make
|
||||
|
||||
#run test.py
|
||||
python test.py --ckpt=pretrained_model.ckpt
|
||||
python test.py --ckpt pretrained_model.ckpt --TEST_ROOT_DIR [test root path]
|
||||
|
||||
#download eval method from [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization).
|
||||
#click "My Methods" button,then download Evaluation Scripts
|
||||
|
@ -95,6 +95,7 @@ sh scripts/run_eval_ascend.sh
|
|||
```path
|
||||
└── PSENet
|
||||
├── export.py // export mindir file
|
||||
├── postprocess.py // 310 Inference post-processing script
|
||||
├── __init__.py
|
||||
├── mindspore_hub_conf.py // hub config file
|
||||
├── README_CN.md // descriptions about PSENet in Chinese
|
||||
|
@ -102,8 +103,13 @@ sh scripts/run_eval_ascend.sh
|
|||
├── scripts
|
||||
├── run_distribute_train.sh // shell script for distributed
|
||||
└── run_eval_ascend.sh // shell script for evaluation
|
||||
├── ascend310_infer // application for 310 inference
|
||||
├── src
|
||||
├── config.py // parameter configuration
|
||||
├── model_utils
|
||||
├──config.py // Parameter config
|
||||
├──moxing_adapter.py // modelarts device configuration
|
||||
├──device_adapter.py // Device Config
|
||||
├──local_adapter.py // local device config
|
||||
├── dataset.py // creating dataset
|
||||
├── ETSNET
|
||||
├── base.py // convolution and BN operator
|
||||
|
@ -122,19 +128,18 @@ sh scripts/run_eval_ascend.sh
|
|||
├── network_define.py // learning ratio generation
|
||||
├── test.py // test script
|
||||
├── train.py // training script
|
||||
|
||||
├── default_config.yaml // config file
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
```python
|
||||
Major parameters in train.py and config.py are:
|
||||
```default_config.yaml
|
||||
Major parameters in default_config.yaml are:
|
||||
|
||||
--pre_trained: Whether training from scratch or training based on the
|
||||
pre-trained model.Optional values are True, False.
|
||||
--device_id: Device ID used to train or evaluate the dataset. Ignore it
|
||||
when you use train.sh for distributed training.
|
||||
--device_num: devices used when you use train.sh for distributed training.
|
||||
|
||||
```
|
||||
|
||||
|
@ -142,8 +147,12 @@ Major parameters in train.py and config.py are:
|
|||
|
||||
### Distributed Training
|
||||
|
||||
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
|
||||
|
||||
Please follow the instructions in the link below: <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
|
||||
|
||||
```shell
|
||||
sh scripts/run_distribute_train.sh rank_table_file pretrained_model.ckpt
|
||||
sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
|
||||
```
|
||||
|
||||
rank_table_file which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
|
@ -164,7 +173,66 @@ device_1/log:epcoh: 2, step: 40, loss is 0.76629
|
|||
|
||||
### run test code
|
||||
|
||||
python test.py --ckpt=./device*/ckpt*/ETSNet-*.ckpt
|
||||
```test
|
||||
python test.py --ckpt [CKPK PATH] --TEST_ROOT_DIR [TEST DATA DIR]
|
||||
|
||||
```
|
||||
|
||||
- running on ModelArts
|
||||
- If you want to train the model on modelarts, you can refer to the [official guidance document] of modelarts (https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```python
|
||||
# Example of using distributed training on modelarts :
|
||||
# Data set storage method
|
||||
|
||||
# ├── ICDAR2015 # dir
|
||||
# ├── train # train dir
|
||||
# ├── ic15 # train_dataset dir
|
||||
# ├── ch4_training_images
|
||||
# ├── ch4_training_localization_transcription_gt
|
||||
# ├── train_predtrained # predtrained dir
|
||||
# ├── eval # eval dir
|
||||
# ├── ic15 # eval dataset dir
|
||||
# ├── ch4_test_images
|
||||
# ├── challenge4_Test_Task1_GT
|
||||
# ├── checkpoint # ckpt files dir
|
||||
|
||||
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters) 。
|
||||
# a. set "enable_modelarts=True" 。
|
||||
# set "run_distribute=True"
|
||||
# set "TRAIN_MODEL_SAVE_PATH=/cache/train/outputs_imagenet/"
|
||||
# set "TRAIN_ROOT_DIR=/cache/data/ic15/"
|
||||
# set "pre_trained=/cache/data/train_predtrained/pred file name" Without pre-training weights train_pretrained=""
|
||||
|
||||
# b. add "enable_modelarts=True" Parameters are on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
|
||||
# (2) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) Set the code path on the modelarts interface "/path/psenet"。
|
||||
# (4) Set the model's startup file on the modelarts interface "train.py" 。
|
||||
# (5) Set the data path of the model on the modelarts interface ".../ICDAR2015/train"(choices ICDAR2015/train Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
# (6) start trainning the model。
|
||||
|
||||
# Example of using model inference on modelarts
|
||||
# (1) Place the trained model to the corresponding position of the bucket。
|
||||
# (2) chocie a or b。
|
||||
# a. set "enable_modelarts=True" 。
|
||||
# set "TEST_ROOT_DIR=/cache/data/ic15/"
|
||||
# set "ckpt=/cache/data/checkpoint/ckpt file"
|
||||
|
||||
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
|
||||
# (3) Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (4) Set the code path on the modelarts interface "/path/psenet"。
|
||||
# (5) Set the model's startup file on the modelarts interface "eval.py" 。
|
||||
# (6) Set the data path of the model on the modelarts interface ".../ICDAR2015/eval"(choices ICDAR2015/eval Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
# (7) Start model inference。
|
||||
```
|
||||
|
||||
### Eval Script for ICDAR2015
|
||||
|
||||
|
@ -189,12 +257,33 @@ Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean
|
|||
### [Export MindIR](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
python export.py --ckpt [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
The ckpt_file parameter is required,
|
||||
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
|
||||
|
||||
- Export MindIR on Modelarts
|
||||
|
||||
```Modelarts
|
||||
Export MindIR example on ModelArts
|
||||
Data storage method is the same as training
|
||||
# (1) Choose either a (modify yaml file parameters) or b (modelArts create training job to modify parameters)。
|
||||
# a. set "enable_modelarts=True"
|
||||
# set "file_name=/cache/train/psenet"
|
||||
# set "file_format=MINDIR"
|
||||
# set "ckpt_file=/cache/data/checkpoint file name"
|
||||
|
||||
# b. Add "enable_modelarts=True" parameter on the interface of modearts。
|
||||
# Set the parameters required by method a on the modelarts interface
|
||||
# Note: The path parameter does not need to be quoted
|
||||
# (2)Set the path of the network configuration file "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) Set the code path on the modelarts interface "/path/psenet"。
|
||||
# (4) Set the model's startup file on the modelarts interface "export.py" 。
|
||||
# (5) Set the data path of the model on the modelarts interface ".../ICDAR2015/eval/checkpoint"(choices ICDAR2015/eval/checkpoint Folder path) ,
|
||||
# The output path of the model "Output file path" and the log path of the model "Job log path" 。
|
||||
```
|
||||
|
||||
### Infer on Ascend310
|
||||
|
||||
Before performing inference, the mindir file must be exported by `export.py` script. We only provide an example of inference using MINDIR model.
|
||||
|
|
|
@ -68,7 +68,7 @@
|
|||
|
||||
```python
|
||||
# 分布式训练运行示例
|
||||
sh scripts/run_distribute_train.sh pretrained_model.ckpt
|
||||
sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
|
||||
|
||||
# 下载opencv库
|
||||
download pyblind11, opencv3.4
|
||||
|
@ -77,14 +77,16 @@ download pyblind11, opencv3.4
|
|||
setup pyblind11(install the library by the pip command)
|
||||
setup opencv3.4(compile source code install the library)
|
||||
|
||||
# 输入路径,运行Makefile,找到产品文件
|
||||
cd ./src/ETSNET/pse/;make
|
||||
|
||||
# 运行test.py
|
||||
python test.py --ckpt=pretrained_model.ckpt
|
||||
|
||||
# 单击[此处](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization)下载评估方法
|
||||
# 点击"我的方法"按钮,下载评估脚本
|
||||
|
||||
# 输入路径,运行Makefile,找到产品文件
|
||||
cd ./src/ETSNET/pse/;make clean&&make
|
||||
|
||||
# 运行test.py
|
||||
python test.py --ckpt pretrained_model.ckpt --TEST_ROOT_DIR [test root path]
|
||||
|
||||
|
||||
download script.py
|
||||
# 运行评估示例
|
||||
sh scripts/run_eval_ascend.sh
|
||||
|
@ -98,13 +100,19 @@ sh scripts/run_eval_ascend.sh
|
|||
└── PSENet
|
||||
├── export.py // mindir转换脚本
|
||||
├── mindspore_hub_conf.py // 网络模型
|
||||
├─postprogress.py # 310推理后处理脚本
|
||||
├── README.md // PSENet相关描述英文版
|
||||
├── README_CN.md // PSENet相关描述中文版
|
||||
├── scripts
|
||||
├── run_distribute_train.sh // 用于分布式训练的shell脚本
|
||||
└── run_eval_ascend.sh // 用于评估的shell脚本
|
||||
├─run_infer_310.sh # Ascend 310 推理shell脚本
|
||||
├── src
|
||||
├── config.py // 参数配置
|
||||
├──model_utils
|
||||
├──config.py # 参数配置
|
||||
├──device_adapter.py # 设备相关信息
|
||||
├──local_adapter.py # 设备相关信息
|
||||
├──moxing_adapter.py # 装饰器(主要用于ModelArts数据拷贝)
|
||||
├── dataset.py // 创建数据集
|
||||
├── ETSNET
|
||||
├── base.py // 卷积和BN算子
|
||||
|
@ -123,26 +131,29 @@ sh scripts/run_eval_ascend.sh
|
|||
├── network_define.py // PSENet架构
|
||||
├── test.py // 测试脚本
|
||||
├── train.py // 训练脚本
|
||||
|
||||
├─default_config.yaml # 参数文件
|
||||
├─ma-pre-start.sh # modelarts配置系统环境变量
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
```python
|
||||
train.py和config.py中主要参数如下:
|
||||
```default_config.yaml
|
||||
配置文件中主要参数如下:
|
||||
|
||||
-- pre_trained:是从零开始训练还是基于预训练模型训练。可选值为True、False。
|
||||
-- device_id:用于训练或评估数据集的设备ID。当使用train.sh进行分布式训练时,忽略此参数。
|
||||
-- device_num:使用train.sh进行分布式训练时使用的设备。
|
||||
|
||||
-- device_id:用于训练或评估数据集或导出的设备ID。当使用train.sh进行分布式训练时,忽略此参数。
|
||||
```
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 分布式训练
|
||||
|
||||
分布式训练需要提前创建JSON格式的HCCL配置文件。
|
||||
|
||||
请遵循链接中的说明:[链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)
|
||||
|
||||
```shell
|
||||
sh scripts/run_distribute_train.sh pretrained_model.ckpt
|
||||
sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
|
||||
```
|
||||
|
||||
上述shell脚本将在后台运行分布训练。可以通过`device[X]/test_*.log`文件查看结果。
|
||||
|
@ -162,7 +173,65 @@ device_1/log:epcoh: 2, step: 40,loss is 0.76629
|
|||
|
||||
### 运行测试代码
|
||||
|
||||
python test.py --ckpt=./device*/ckpt*/ETSNet-*.ckpt
|
||||
```test
|
||||
python test.py --ckpt [CKPK PATH] --TEST_ROOT_DIR [TEST DATA DIR]
|
||||
|
||||
```
|
||||
|
||||
- 如果要在modelarts上进行模型的训练,可以参考modelarts的[官方指导文档](https://support.huaweicloud.com/modelarts/) 开始进行模型的训练和推理,具体操作如下:
|
||||
|
||||
```ModelArts
|
||||
# 在ModelArts上使用分布式训练示例:
|
||||
# 数据集存放方式
|
||||
|
||||
# ├── ICDAR2015 # dir
|
||||
# ├── train # train dir
|
||||
# ├── ic15 # train_dataset dir
|
||||
# ├── ch4_training_images
|
||||
# ├── ch4_training_localization_transcription_gt
|
||||
# ├── train_predtrained # predtrained dir
|
||||
# ├── eval # eval dir
|
||||
# ├── ic15 # eval dataset dir
|
||||
# ├── ch4_test_images
|
||||
# ├── challenge4_Test_Task1_GT
|
||||
# ├── checkpoint # ckpt files dir
|
||||
|
||||
# (1) 选择a(修改yaml文件参数)或者b(ModelArts创建训练作业修改参数)其中一种方式。
|
||||
# a. 设置 "enable_modelarts=True"
|
||||
# 设置 "run_distribute=True"
|
||||
# 设置 "TRAIN_MODEL_SAVE_PATH=/cache/train/outputs/"
|
||||
# 设置 "TRAIN_ROOT_DIR=/cache/data/ic15/"
|
||||
# 设置 "pre_trained=/cache/data/train_predtrained/pred file name" 如果没有预训练权重 pre_trained=""
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
|
||||
# (2)设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) 在modelarts的界面上设置代码的路径 "/path/psenet"。
|
||||
# (4) 在modelarts的界面上设置模型的启动文件 "train.py" 。
|
||||
# (5) 在modelarts的界面上设置模型的数据路径 ".../ICDAR2015/train"(选择ICDAR2015/train文件夹路径) ,
|
||||
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||
# (6) 开始模型的训练。
|
||||
|
||||
# 在modelarts上使用模型推理的示例
|
||||
# (1) 把训练好的模型地方到桶的对应位置。
|
||||
# (2) 选择a或者b其中一种方式。
|
||||
# a.设置 "enable_modelarts=True"
|
||||
# 设置 "TEST_ROOT_DIR=/cache/data/ic15"
|
||||
# 设置 "ckpt=/cache/data/checkpoint/ckpt file"
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
|
||||
# (3) 设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (4) 在modelarts的界面上设置代码的路径 "/path/psenet"。
|
||||
# (5) 在modelarts的界面上设置模型的启动文件 "eval.py" 。
|
||||
# (6) 在modelarts的界面上设置模型的数据路径 "../ICDAR2015/eval"(选择ICDAR2015/eval文件夹路径) ,
|
||||
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||
# (7) 开始模型的推理。
|
||||
```
|
||||
|
||||
### ICDAR2015评估脚本
|
||||
|
||||
|
@ -187,12 +256,33 @@ Calculated!{"precision": 0.8147966668299853,"recall":0.8006740491092923,"h
|
|||
### [导出MindIR](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
python export.py --ckpt [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
参数ckpt_file为必填项,
|
||||
参数ckpt为必填项,
|
||||
`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中选择。
|
||||
|
||||
- 在modelarts上导出MindIR
|
||||
|
||||
```Modelarts
|
||||
在ModelArts上导出MindIR示例
|
||||
数据集存放方式同Modelart训练
|
||||
# (1) 选择a(修改yaml文件参数)或者b(ModelArts创建训练作业修改参数)其中一种方式。
|
||||
# a. 设置 "enable_modelarts=True"
|
||||
# 设置 "file_name=/cache/train/psenet"
|
||||
# 设置 "file_format=MINDIR"
|
||||
# 设置 "ckpt_file=/cache/data/checkpoint file name"
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
# (2)设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) 在modelarts的界面上设置代码的路径 "/path/psenet"。
|
||||
# (4) 在modelarts的界面上设置模型的启动文件 "export.py" 。
|
||||
# (5) 在modelarts的界面上设置模型的数据路径 ".../ICDAR2015/eval/checkpoint"(选择ICDAR2015/eval/checkpoint文件夹路径) ,
|
||||
# MindIR的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||
```
|
||||
|
||||
### 在Ascend310执行推理
|
||||
|
||||
在执行推理前,mindir文件必须通过`export.py`脚本导出。以下展示了使用minir模型执行推理的示例。
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
checkpoint_path: "./checkpoint/"
|
||||
checkpoint_file: "./checkpoint/.ckpt"
|
||||
modelarts_home: "/home/work/user-job-dir"
|
||||
object_name: "psenet"
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
pre_trained: ""
|
||||
INFER_LONG_SIZE: 1920
|
||||
KERNEL_NUM: 7
|
||||
run_distribute: False
|
||||
|
||||
# backbone
|
||||
BACKBONE_LAYER_NUMS: [3, 4, 6, 3]
|
||||
BACKBONE_IN_CHANNELS: [64, 256, 512, 1024]
|
||||
BACKBONE_OUT_CHANNELS: [256, 512, 1024, 2048]
|
||||
|
||||
# neck
|
||||
NECK_OUT_CHANNEL: 256
|
||||
|
||||
# lr
|
||||
BASE_LR: 1e-3
|
||||
TRAIN_TOTAL_ITER: 58000
|
||||
WARMUP_STEP: 620
|
||||
WARMUP_RATIO: 1/3
|
||||
|
||||
# dataset for train
|
||||
TRAIN_ROOT_DIR: ""
|
||||
TRAIN_LONG_SIZE: 640
|
||||
TRAIN_MIN_SCALE: 0.4
|
||||
TRAIN_BATCH_SIZE: 4
|
||||
TRAIN_REPEAT_NUM: 1800
|
||||
TRAIN_DROP_REMAINDER: True
|
||||
TRAIN_MODEL_SAVE_PATH: "./"
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
ckpt: ""
|
||||
TEST_ROOT_DIR: ""
|
||||
TEST_BUFFER_SIZE: 4
|
||||
TEST_DROP_REMAINDER: False
|
||||
INFERENCE: True
|
||||
|
||||
|
||||
#export options
|
||||
device_id: 0
|
||||
batch_size: 1
|
||||
file_name: "psenet"
|
||||
file_format: "MINDIR"
|
||||
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
run_distribute: "Run distribute, default is false."
|
||||
pre_trained: "Pretrain file path"
|
||||
ckpt: "trained model path"
|
||||
device_id: "device id"
|
||||
batch_size: "batch size"
|
||||
file_name: "output file name"
|
||||
file_format: "file format choices[AIR, MINDIR, ONNX]"
|
||||
object_home: "your direction name"
|
||||
modelarts_home: "modelarts working path"
|
|
@ -15,33 +15,29 @@
|
|||
"""
|
||||
##############export checkpoint file into air models#################
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.config import config
|
||||
from src.model_utils.config import config
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
parser = argparse.ArgumentParser(description="psenet export")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="psenet", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=config.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@moxing_wrapper(pre_process=None)
|
||||
def model_export():
|
||||
net = ETSNet(config)
|
||||
param_dict = load_checkpoint(args.ckpt_file)
|
||||
param_dict = load_checkpoint(config.ckpt)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.ones([args.batch_size, 3, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE]), ms.float32)
|
||||
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
||||
input_arr = Tensor(np.ones([config.batch_size, 3, config.INFER_LONG_SIZE, config.INFER_LONG_SIZE]), ms.float32)
|
||||
export(net, input_arr, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_export()
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unlesee required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========================================================================================================================
|
||||
|
||||
export ASCEND_GLOBAL_LOG_LEVEL=3
|
||||
export ASCEND_SLOG_PRINT_TO_STDOUT=0
|
||||
export ASCEND_GLOBAL_EVENT_ENABLE=0
|
||||
export OPENCV_HOME=/usr/local
|
||||
export CPLUS_INCLUDE_PATH=/usr/local/include
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64:/usr/local/lib:/usr/local/include
|
||||
export MINDSPORE_HOME=/home/work/user-job-dir/psenet/mindspore
|
|
@ -17,9 +17,9 @@
|
|||
current_exec_path=$(pwd)
|
||||
echo 'current_exec_path: '${current_exec_path}
|
||||
|
||||
if [ $# != 2 ]
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_FILE] [PRETRAINED_PATH]"
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_FILE] [PRETRAINED_PATH] [TRAIN_ROOT_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -74,7 +74,6 @@ do
|
|||
cd ${current_exec_path}/device_$i || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python ${current_exec_path}/train.py --run_distribute --device_id $i --pre_trained $PATH2 --device_num ${DEVICE_NUM} >test_deep$i.log 2>&1 &
|
||||
python ${current_exec_path}/train.py --run_distribute=True --pre_trained $PATH2 --TRAIN_ROOT_DIR=$3 >test_deep$i.log 2>&1 &
|
||||
cd ${current_exec_path} || exit
|
||||
done
|
||||
|
||||
|
|
|
@ -1,51 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"INFER_LONG_SIZE": 1920,
|
||||
"KERNEL_NUM": 7,
|
||||
"INFERENCE": True, # INFER MODE\TRAIN MODE
|
||||
|
||||
# backbone
|
||||
"BACKBONE_LAYER_NUMS": [3, 4, 6, 3],
|
||||
"BACKBONE_IN_CHANNELS": [64, 256, 512, 1024],
|
||||
"BACKBONE_OUT_CHANNELS": [256, 512, 1024, 2048],
|
||||
|
||||
# neck
|
||||
"NECK_OUT_CHANNEL": 256,
|
||||
|
||||
# lr
|
||||
"BASE_LR": 2e-3,
|
||||
"TRAIN_TOTAL_ITER": 58000,
|
||||
"WARMUP_STEP": 620,
|
||||
"WARMUP_RATIO": 1/3,
|
||||
|
||||
# dataset for train
|
||||
"TRAIN_ROOT_DIR": "psenet/ic15/",
|
||||
"TRAIN_LONG_SIZE": 640,
|
||||
"TRAIN_MIN_SCALE": 0.4,
|
||||
"TRAIN_BATCH_SIZE": 4,
|
||||
"TRAIN_REPEAT_NUM": 1800,
|
||||
"TRAIN_DROP_REMAINDER": True,
|
||||
"TRAIN_MODEL_SAVE_PATH": "./checkpoints/",
|
||||
|
||||
# dataset for test
|
||||
"TEST_ROOT_DIR": "psenet/ic15/",
|
||||
"TEST_BUFFER_SIZE": 4,
|
||||
"TEST_DROP_REMAINDER": False,
|
||||
})
|
|
@ -22,11 +22,9 @@ from PIL import Image
|
|||
import numpy as np
|
||||
import Polygon as plg
|
||||
import pyclipper
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.py_transforms as py_transforms
|
||||
|
||||
from src.config import config
|
||||
from src.model_utils.config import config
|
||||
|
||||
__all__ = ['train_dataset_creator', 'test_dataset_creator']
|
||||
|
||||
|
@ -179,9 +177,9 @@ class TrainDataset:
|
|||
self.kernel_num = config.KERNEL_NUM
|
||||
self.min_scale = config.TRAIN_MIN_SCALE
|
||||
|
||||
root_dir = os.path.join(os.path.join(os.path.dirname(__file__), '..'), config.TRAIN_ROOT_DIR)
|
||||
ic15_train_data_dir = root_dir + 'ch4_training_images/'
|
||||
ic15_train_gt_dir = root_dir + 'ch4_training_localization_transcription_gt/'
|
||||
root_dir = config.TRAIN_ROOT_DIR
|
||||
ic15_train_data_dir = os.path.join(root_dir, 'ch4_training_images/')
|
||||
ic15_train_gt_dir = os.path.join(root_dir, 'ch4_training_localization_transcription_gt/')
|
||||
|
||||
self.img_size = self.img_size if \
|
||||
(self.img_size is None or isinstance(self.img_size, tuple)) \
|
||||
|
@ -276,7 +274,7 @@ class TrainDataset:
|
|||
|
||||
|
||||
def IC15_TEST_Generator():
|
||||
ic15_test_data_dir = config.TEST_ROOT_DIR + 'ch4_test_images/'
|
||||
ic15_test_data_dir = os.path.join(config.TEST_ROOT_DIR, 'ch4_test_images/')
|
||||
img_size = config.INFER_LONG_SIZE
|
||||
|
||||
img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size)
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
|
||||
global_yaml = '../../default_config.yaml'
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml
|
||||
|
||||
Args:
|
||||
parser: Parent parser
|
||||
cfg: Base configuration
|
||||
helper: Helper description
|
||||
cfg_path: Path to the default yaml config
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml')
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError('Failed to parse yaml')
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments
|
||||
|
||||
Args:
|
||||
args: command line arguments
|
||||
cfg: Base configuration
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='default name', add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, global_yaml),
|
||||
help='Config file path')
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return 'Local Job'
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
|
||||
_global_syn_count = 0
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local
|
||||
Uploca data from local directory to remote obs in contrast
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_syn_count
|
||||
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
|
||||
_global_syn_count += 1
|
||||
|
||||
# Each server contains 8 devices as most
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print('from path: ', from_path)
|
||||
print('to path: ', to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print('===finished data synchronization===')
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print('===save flag===')
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
print('Finish sync data from {} to {}'.format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print('Dataset downloaded: ', os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
if not os.path.exists(config.load_path):
|
||||
# os.makedirs(config.load_path)
|
||||
print('=' * 20 + 'makedirs')
|
||||
if os.path.isdir(config.load_path):
|
||||
print('=' * 20 + 'makedirs success')
|
||||
else:
|
||||
print('=' * 20 + 'makedirs fail')
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print('Preload downloaded: ', os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print('Workspace downloaded: ', os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print('Start to copy output directory')
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -18,26 +18,22 @@ import os
|
|||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
import cv2
|
||||
from mindspore import Tensor, context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config
|
||||
from src.dataset import test_dataset_creator
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.ETSNET.pse import pse
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
parser = argparse.ArgumentParser(description='Hyperparams')
|
||||
parser.add_argument("--ckpt", type=str, default=0, help='trained model path.')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False,
|
||||
save_graphs_path=".")
|
||||
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
|
@ -55,12 +51,14 @@ class AverageMeter():
|
|||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def sort_to_clockwise(points):
|
||||
center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), points), [len(points)] * 2))
|
||||
clockwise_points = sorted(points, key=lambda coord: (-135 - math.degrees(
|
||||
math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)
|
||||
return clockwise_points
|
||||
|
||||
|
||||
def write_result_as_txt(img_name, bboxes, path):
|
||||
if not os.path.isdir(path):
|
||||
os.makedirs(path)
|
||||
|
@ -76,17 +74,43 @@ def write_result_as_txt(img_name, bboxes, path):
|
|||
for line in lines:
|
||||
f.write(line)
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
local_path = '{}/{}'.format(config.modelarts_home, config.object_name)
|
||||
|
||||
os.system('cd {}&&tar -zxvf opencv-3.4.9.tar.gz'.format(local_path))
|
||||
|
||||
cmake_command = 'cmake -D CMAKE_BUILD_TYPE=Release -D CMAKE_INSTALL=/usr/local ..&&make -j16&&sudo make install'
|
||||
os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command))
|
||||
|
||||
os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path))
|
||||
os.system('cd {}&&sed -i ’s/\r//‘ scripts/run_eval_ascend.sh')
|
||||
|
||||
|
||||
def modelarts_post_process():
|
||||
local_path = '{}/{}'.format(config.modelarts_home, config.object_name)
|
||||
os.system('cd {}&&sh scripts/run_eval_ascend.sh'.format(local_path))
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process, post_process=modelarts_post_process)
|
||||
def test():
|
||||
if not os.path.isdir('./res/submit_ic15/'):
|
||||
os.makedirs('./res/submit_ic15/')
|
||||
if not os.path.isdir('./res/vis_ic15/'):
|
||||
os.makedirs('./res/vis_ic15/')
|
||||
from src.ETSNET.pse import pse
|
||||
|
||||
local_path = ""
|
||||
if config.enable_modelarts:
|
||||
local_path = os.path.join(config.modelarts_home, config.object_name) + '/'
|
||||
print('local_path: ', local_path)
|
||||
|
||||
if not os.path.isdir('{}./res/submit_ic15/'.format(local_path)):
|
||||
os.makedirs('{}./res/submit_ic15/'.format(local_path))
|
||||
if not os.path.isdir('{}./res/vis_ic15/'.format(local_path)):
|
||||
os.makedirs('{}./res/vis_ic15/'.format(local_path))
|
||||
ds = test_dataset_creator()
|
||||
|
||||
config.INFERENCE = True
|
||||
net = ETSNet(config)
|
||||
print(args.ckpt)
|
||||
param_dict = load_checkpoint(args.ckpt)
|
||||
print(config.ckpt)
|
||||
param_dict = load_checkpoint(config.ckpt)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('parameters loaded!')
|
||||
|
||||
|
@ -149,8 +173,9 @@ def test():
|
|||
end_pts = time.time()
|
||||
|
||||
# save res
|
||||
cv2.imwrite('./res/vis_ic15/{}'.format(img_name), img[:, :, [2, 1, 0]].copy())
|
||||
write_result_as_txt(img_name, bboxes, './res/submit_ic15/')
|
||||
cv2.imwrite('{}./res/vis_ic15/{}'.format(local_path, img_name), img[:, :, [2, 1, 0]].copy())
|
||||
write_result_as_txt(img_name, bboxes, '{}./res/submit_ic15/'.format(local_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
|
|
@ -14,51 +14,55 @@
|
|||
# ============================================================================
|
||||
|
||||
|
||||
import argparse
|
||||
from ast import literal_eval as liter
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.dataset import train_dataset_creator
|
||||
from src.config import config
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.ETSNET.dice_loss import DiceLoss
|
||||
from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
|
||||
parser = argparse.ArgumentParser(description='Hyperparams')
|
||||
parser.add_argument('--run_distribute', default=False, action='store_true',
|
||||
help='Run distribute, default is false.')
|
||||
parser.add_argument('--pre_trained', type=str, default='', help='Pretrain file path.')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1.')
|
||||
args = parser.parse_args()
|
||||
|
||||
set_seed(1)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train():
|
||||
rank_id = 0
|
||||
if args.run_distribute:
|
||||
context.set_auto_parallel_context(device_num=args.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
config.BASE_LR = liter(config.BASE_LR)
|
||||
config.WARMUP_RATIO = liter(config.WARMUP_RATIO)
|
||||
|
||||
device_num = get_device_num()
|
||||
if config.run_distribute:
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
rank_id = get_rank_id()
|
||||
|
||||
# dataset/network/criterion/optim
|
||||
ds = train_dataset_creator(rank_id, args.device_num)
|
||||
ds = train_dataset_creator(rank_id, device_num)
|
||||
step_size = ds.get_dataset_size()
|
||||
print('Create dataset done!')
|
||||
|
||||
config.INFERENCE = False
|
||||
net = ETSNet(config)
|
||||
net = net.set_train()
|
||||
param_dict = load_checkpoint(args.pre_trained)
|
||||
param_dict = load_checkpoint(config.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('Load Pretrained parameters done!')
|
||||
|
||||
|
@ -69,20 +73,21 @@ def train():
|
|||
|
||||
# warp model
|
||||
net = WithLossCell(net, criterion)
|
||||
if args.run_distribute:
|
||||
net = TrainOneStepCell(net, opt, reduce_flag=True, mean=True, degree=args.device_num)
|
||||
if config.run_distribute:
|
||||
net = TrainOneStepCell(net, opt, reduce_flag=True, mean=True, degree=device_num)
|
||||
else:
|
||||
net = TrainOneStepCell(net, opt)
|
||||
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossCallBack(per_print_times=10)
|
||||
# set and apply parameters of check point config.TRAIN_MODEL_SAVE_PATH
|
||||
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
|
||||
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=3)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf,
|
||||
directory="./ckpt_{}".format(rank_id))
|
||||
directory="{}/ckpt_{}".format(config.TRAIN_MODEL_SAVE_PATH, rank_id))
|
||||
|
||||
model = Model(net)
|
||||
model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=True, callbacks=[time_cb, loss_cb, ckpoint_cb])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
|
|
|
@ -87,7 +87,6 @@ MSCOCO2017
|
|||
├─run_infer_310.sh # Ascend推理shell脚本
|
||||
├─run_eval.sh # 使用Ascend环境运行推理脚本
|
||||
├─src
|
||||
├─config.py # 参数配置
|
||||
├─dataset.py # 数据预处理
|
||||
├─retinanet.py # 网络模型定义
|
||||
├─init_params.py # 参数初始化
|
||||
|
@ -95,17 +94,24 @@ MSCOCO2017
|
|||
├─coco_eval # coco数据集评估
|
||||
├─box_utils.py # 先验框设置
|
||||
├─_init_.py # 初始化
|
||||
├──model_utils
|
||||
├──config.py # 参数生成
|
||||
├──device_adapter.py # 设备相关信息
|
||||
├──local_adapter.py # 设备相关信息
|
||||
├──moxing_adapter.py # 装饰器(主要用于ModelArts数据拷贝)
|
||||
├─train.py # 网络训练脚本
|
||||
├─export.py # 导出 AIR,MINDIR模型的脚本
|
||||
├─postprogress.py # 310推理后处理脚本
|
||||
└─eval.py # 网络推理脚本
|
||||
└─create_data.py # 构建Mindrecord数据集脚本
|
||||
└─default_config.yaml # 参数配置
|
||||
|
||||
```
|
||||
|
||||
### [脚本参数](#content)
|
||||
|
||||
```python
|
||||
在train.py和config.py脚本中使用到的主要参数是:
|
||||
```default_config.yaml
|
||||
在脚本中使用到的主要参数是:
|
||||
"img_shape": [600, 600], # 图像尺寸
|
||||
"num_retinanet_boxes": 67995, # 设置的先验框总数
|
||||
"match_thershold": 0.5, # 匹配阈值
|
||||
|
@ -125,10 +131,10 @@ MSCOCO2017
|
|||
"num_default": [9, 9, 9, 9, 9], # 单个网格中先验框的个数
|
||||
"extras_out_channels": [256, 256, 256, 256, 256], # 特征层输出通道数
|
||||
"feature_size": [75, 38, 19, 10, 5], # 特征层尺寸
|
||||
"aspect_ratios": [(0.5,1.0,2.0), (0.5,1.0,2.0), (0.5,1.0,2.0), (0.5,1.0,2.0), (0.5,1.0,2.0)], # 先验框大小变化比值
|
||||
"steps": ( 8, 16, 32, 64, 128), # 先验框设置步长
|
||||
"anchor_size":(32, 64, 128, 256, 512), # 先验框尺寸
|
||||
"prior_scaling": (0.1, 0.2), # 用于调节回归与回归在loss中占的比值
|
||||
"aspect_ratios": [[0.5,1.0,2.0], [0.5,1.0,2.0], [0.5,1.0,2.0], [0.5,1.0,2.0], [0.5,1.0,2.0]], # 先验框大小变化比值
|
||||
"steps": [8, 16, 32, 64, 128], # 先验框设置步长
|
||||
"anchor_size":[32, 64, 128, 256, 512], # 先验框尺寸
|
||||
"prior_scaling": [0.1, 0.2], # 用于调节回归与回归在loss中占的比值
|
||||
"gamma": 2.0, # focal loss中的参数
|
||||
"alpha": 0.75, # focal loss中的参数
|
||||
"mindrecord_dir": "/cache/MindRecord_COCO", # mindrecord文件路径
|
||||
|
@ -159,7 +165,7 @@ MSCOCO2017
|
|||
"save_checkpoint": True, # 保存checkpoint
|
||||
"save_checkpoint_epochs": 1, # 保存checkpoint epoch数
|
||||
"keep_checkpoint_max":1, # 保存checkpoint的最大数量
|
||||
"save_checkpoint_path": "./model", # 保存checkpoint的路径
|
||||
"save_checkpoint_path": "./ckpt", # 保存checkpoint的路径
|
||||
"finish_epoch":0, # 已经运行完成的 epoch 数
|
||||
"checkpoint_path":"/home/hitwh1/1.0/ckpt_0/retinanet-500_458_59.ckpt" # 用于验证的checkpoint路径
|
||||
```
|
||||
|
@ -174,11 +180,11 @@ MSCOCO2017
|
|||
# 八卡并行训练示例:
|
||||
|
||||
创建 RANK_TABLE_FILE
|
||||
sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
|
||||
sh scripts/run_distribute_train.sh DEVICE_NUM RANK_TABLE_FILE MINDRECORD_DIR PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
|
||||
|
||||
# 单卡训练示例:
|
||||
|
||||
sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
|
||||
sh scripts/run_single_train.sh DEVICE_ID MINDRECORD_DIR PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
|
||||
|
||||
```
|
||||
|
||||
|
@ -189,22 +195,22 @@ sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR PRE_TRAINED(optional) PRE_TRAINED
|
|||
#### 运行
|
||||
|
||||
```运行
|
||||
训练前,先创建MindRecord文件,以COCO数据集为例
|
||||
训练前,先创建MindRecord文件,以COCO数据集为例,yaml文件配置好coco数据集路径和mindrecord存储路径
|
||||
python create_data.py --dataset coco
|
||||
|
||||
Ascend:
|
||||
# 八卡并行训练示例(在retinanet目录下运行):
|
||||
sh scripts/run_distribute_train.sh 8 500 0.09 RANK_TABLE_FILE(创建的RANK_TABLE_FILE的地址) PRE_TRAINED(预训练checkpoint地址,可选) PRE_TRAINED_EPOCH_SIZE(预训练EPOCH大小,可选)
|
||||
sh scripts/run_distribute_train.sh 8 RANK_TABLE_FILE(创建的RANK_TABLE_FILE的地址) MINDRECORD_DIR(mindrecord数据集文件夹路径) PRE_TRAINED(预训练checkpoint地址,可选) PRE_TRAINED_EPOCH_SIZE(预训练EPOCH大小,可选)
|
||||
|
||||
例如:sh scripts/run_distribute_train.sh 8 500 0.09 scripts/rank_table_8pcs.json
|
||||
例如:sh scripts/run_distribute_train.sh 8 scripts/rank_table_8pcs.json ./cache/mindrecord_coco
|
||||
|
||||
# 单卡训练示例(在retinanet目录下运行):
|
||||
sh scripts/run_single_train.sh 0 500 0.09
|
||||
sh scripts/run_single_train.sh 0 ./cache/mindrecord_coco
|
||||
```
|
||||
|
||||
#### 结果
|
||||
|
||||
训练结果将存储在示例路径中。checkpoint将存储在 `./model` 路径下,训练日志将被记录到 `./log.txt` 中,训练日志部分示例如下:
|
||||
训练结果将存储在示例路径中。checkpoint将存储在 `./ckpt` 路径下,训练日志将被记录到 `./log.txt` 中,训练日志部分示例如下:
|
||||
|
||||
```训练日志
|
||||
epoch: 2 step: 458, loss is 120.56251
|
||||
|
@ -221,6 +227,60 @@ lr:[0.000064]
|
|||
Epoch time: 164531.610, per step time: 359.239
|
||||
```
|
||||
|
||||
- 如果要在modelarts上进行模型的训练,可以参考modelarts的[官方指导文档](https://support.huaweicloud.com/modelarts/) 开始进行模型的训练和推理,具体操作如下:
|
||||
|
||||
```ModelArts
|
||||
# 在ModelArts上使用分布式训练示例:
|
||||
# 数据集存放方式
|
||||
|
||||
# ├── MindRecord_COCO # dir
|
||||
# ├── annotations # annotations dir
|
||||
# ├── instances_val2017.json # annotations file
|
||||
# ├── checkpoint # checkpoint dir
|
||||
# ├── pred_train # predtrained dir
|
||||
# ├── MindRecord_COCO.zip # train mindrecord file and eval mindrecord file
|
||||
|
||||
# (1) 选择a(修改yaml文件参数)或者b(ModelArts创建训练作业修改参数)其中一种方式。
|
||||
# a. 设置 "enable_modelarts=True"
|
||||
# 设置 "distribute=True"
|
||||
# 设置 "keep_checkpoint_max=5"
|
||||
# 设置 "save_checkpoint_path=/cache/train/checkpoint"
|
||||
# 设置 "mindrecord_dir=/cache/data/MindRecord_COCO"
|
||||
# 设置 "epoch_size=550"
|
||||
# 设置 "modelarts_dataset_unzip_name=MindRecord_COCO"
|
||||
# 设置 "pre_trained=/cache/data/train/train_predtrained/pred file name" 如果没有预训练权重 pre_trained=""
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
|
||||
# (2)设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) 在modelarts的界面上设置代码的路径 "/path/retinanet"。
|
||||
# (4) 在modelarts的界面上设置模型的启动文件 "train.py" 。
|
||||
# (5) 在modelarts的界面上设置模型的数据路径 ".../MindRecord_COCO"(选择MindRecord_COCO文件夹路径) ,
|
||||
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||
# (6) 开始模型的训练。
|
||||
|
||||
# 在modelarts上使用模型推理的示例
|
||||
# (1) 把训练好的模型地方到桶的对应位置。
|
||||
# (2) 选择a或者b其中一种方式。
|
||||
# a.设置 "enable_modelarts=True"
|
||||
# 设置 "mindrecord_dir=/cache/data/MindRecord_COCO"
|
||||
# 设置 "checkpoint_path=/cache/data/checkpoint/checkpoint file name"
|
||||
# 设置 "instance_set=/cache/data/MindRecord_COCO/annotations/instances_{}.json"
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
|
||||
# (3) 设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (4) 在modelarts的界面上设置代码的路径 "/path/retinanet"。
|
||||
# (5) 在modelarts的界面上设置模型的启动文件 "eval.py" 。
|
||||
# (6) 在modelarts的界面上设置模型的数据路径 "../MindRecord_COCO"(选择MindRecord_COCO文件夹路径) ,
|
||||
# 模型的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||
# (7) 开始模型的推理。
|
||||
```
|
||||
|
||||
### [评估过程](#content)
|
||||
|
||||
#### <span id="usage">用法</span>
|
||||
|
@ -228,13 +288,13 @@ Epoch time: 164531.610, per step time: 359.239
|
|||
使用shell脚本进行评估。shell脚本的用法如下:
|
||||
|
||||
```eval
|
||||
sh scripts/run_eval.sh [DATASET] [DEVICE_ID]
|
||||
sh scripts/run_eval.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [CHECKPOINT_PATH] [ANN_FILE PATH]
|
||||
```
|
||||
|
||||
#### <span id="running">运行</span>
|
||||
|
||||
```eval运行
|
||||
sh scripts/run_eval.sh coco 0
|
||||
sh scripts/run_eval.sh 0 coco /cache/mindrecord_dir/ /cache/checkpoint/retinanet_500-458.ckpt /cache/anno_path/instances_{}.json
|
||||
```
|
||||
|
||||
> checkpoint 可以在训练过程中产生.
|
||||
|
@ -269,7 +329,7 @@ mAP: 0.34747137754625645
|
|||
导出模型前要修改config.py文件中的checkpoint_path配置项,值为checkpoint的路径。
|
||||
|
||||
```shell
|
||||
python export.py --run_platform [RUN_PLATFORM] --file_format[EXPORT_FORMAT]
|
||||
python export.py --file_name [RUN_PLATFORM] --file_format[EXPORT_FORMAT] --checkpoint_path [CHECKPOINT PATH]
|
||||
```
|
||||
|
||||
`EXPORT_FORMAT` 可选 ["AIR", "MINDIR"]
|
||||
|
@ -277,7 +337,27 @@ python export.py --run_platform [RUN_PLATFORM] --file_format[EXPORT_FORMAT]
|
|||
#### <span id="running">运行</span>
|
||||
|
||||
```运行
|
||||
python export.py --run_platform ascend --file_format MINDIR
|
||||
python export.py --file_name retinanet --file_format MINDIR --checkpoint_path /cache/checkpoint/retinanet_550-458.ckpt
|
||||
```
|
||||
|
||||
- 在modelarts上导出MindIR
|
||||
|
||||
```Modelarts
|
||||
在ModelArts上导出MindIR示例
|
||||
# (1) 选择a(修改yaml文件参数)或者b(ModelArts创建训练作业修改参数)其中一种方式。
|
||||
# a. 设置 "enable_modelarts=True"
|
||||
# 设置 "file_name=/cache/train/cnnctc"
|
||||
# 设置 "file_format=MINDIR"
|
||||
# 设置 "checkpoint_path=/cache/data/checkpoint/checkpoint file name"
|
||||
|
||||
# b. 增加 "enable_modelarts=True" 参数在modearts的界面上。
|
||||
# 在modelarts的界面上设置方法a所需要的参数
|
||||
# 注意:路径参数不需要加引号
|
||||
# (2)设置网络配置文件的路径 "_config_path=/The path of config in default_config.yaml/"
|
||||
# (3) 在modelarts的界面上设置代码的路径 "/path/retinanet"。
|
||||
# (4) 在modelarts的界面上设置模型的启动文件 "export.py" 。
|
||||
# (5) 在modelarts的界面上设置模型的数据路径 ".../MindRecord_COCO"(选择MindRecord_COCO文件夹路径) ,
|
||||
# MindIR的输出路径"Output file path" 和模型的日志路径 "Job log path" 。
|
||||
```
|
||||
|
||||
### [推理过程](#content)
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
need_modelarts_dataset_unzip: True
|
||||
modelarts_dataset_unzip_name: "MindRecord_COCO"
|
||||
|
||||
# ======================================================================================
|
||||
# common options
|
||||
distribute: False
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
img_shape: [600, 600]
|
||||
num_retinanet_boxes: 67995
|
||||
match_thershold: 0.5
|
||||
nms_thershold: 0.6
|
||||
min_score: 0.1
|
||||
max_boxes: 100
|
||||
|
||||
# learning rate settings
|
||||
lr: 0.1
|
||||
global_step: 0
|
||||
lr_init: 1e-6
|
||||
lr_end_rate: 5e-3
|
||||
warmup_epochs1: 2
|
||||
warmup_epochs2: 5
|
||||
warmup_epochs3: 23
|
||||
warmup_epochs4: 60
|
||||
warmup_epochs5: 160
|
||||
momentum: 0.9
|
||||
weight_decay: 1.5e-4
|
||||
|
||||
# network
|
||||
num_default: [9, 9, 9, 9, 9]
|
||||
extras_out_channels: [256, 256, 256, 256, 256]
|
||||
feature_size: [75, 38, 19, 10, 5]
|
||||
aspect_ratios: [[0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0]]
|
||||
steps: [8, 16, 32, 64, 128]
|
||||
anchor_size: [32, 64, 128, 256, 512]
|
||||
prior_scaling: [0.1, 0.2]
|
||||
gamma: 2.0
|
||||
alpha: 0.75
|
||||
num_classes: 81
|
||||
|
||||
# `mindrecord_dir` and `coco_root` are better to use absolute path.
|
||||
mindrecord_dir: "./"
|
||||
coco_root: "./"
|
||||
train_data_type: "train2017"
|
||||
val_data_type: "val2017"
|
||||
instances_set: "./instances_{}.json"
|
||||
coco_classes: ["background", "person", "bicycle", "car", "motorcycle", "airplane", "bus",
|
||||
"train", "truck", "boat", "traffic light", "fire hydrant",
|
||||
"stop sign", "parking meter", "bench", "bird", "cat", "dog",
|
||||
"horse", "sheep", "cow", "elephant", "bear", "zebra",
|
||||
"giraffe", "backpack", "umbrella", "handbag", "tie",
|
||||
"suitcase", "frisbee", "skis", "snowboard", "sports ball",
|
||||
"kite", "baseball bat", "baseball glove", "skateboard",
|
||||
"surfboard", "tennis racket", "bottle", "wine glass", "cup",
|
||||
"fork", "knife", "spoon", "bowl", "banana", "apple",
|
||||
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
|
||||
"donut", "cake", "chair", "couch", "potted plant", "bed",
|
||||
"dining table", "toilet", "tv", "laptop", "mouse", "remote",
|
||||
"keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
|
||||
"refrigerator", "book", "clock", "vase", "scissors",
|
||||
"teddy bear", "hair drier", "toothbrush"]
|
||||
|
||||
|
||||
# The annotation.json position of voc validation dataset
|
||||
voc_root: ""
|
||||
|
||||
# voc original dataset
|
||||
voc_dir: ""
|
||||
|
||||
# if coco or voc used, `image_dir` and `anno_path` are useless
|
||||
image_dir: ""
|
||||
anno_path: ""
|
||||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 1
|
||||
keep_checkpoint_max: 10
|
||||
save_checkpoint_path: "./ckpt"
|
||||
finish_epoch: 0
|
||||
|
||||
# optimiter options
|
||||
workers: 24
|
||||
mode: "sink"
|
||||
epoch_size: 550
|
||||
batch_size: 32
|
||||
pre_trained: ""
|
||||
pre_trained_epoch_size: 0
|
||||
loss_scale: 1024
|
||||
filter_weight: False
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
dataset: "coco"
|
||||
checkpoint_path: ""
|
||||
|
||||
# ======================================================================================
|
||||
# export options
|
||||
device_id: 0
|
||||
file_format: "MINDIR"
|
||||
export_batch_size: 1
|
||||
file_name: "retinanet"
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
workers: "Num parallel workers."
|
||||
lr: "Learning rate, default is 0.1."
|
||||
mode: "Run sink mode or not, default is sink."
|
||||
epoch_size: "Epoch size, default is 500."
|
||||
batch_size: "Batch size, default is 32."
|
||||
pre_trained: "Pretrained Checkpoint file path."
|
||||
pre_trained_epoch_size: "Pretrained epoch size."
|
||||
save_checkpoint_epochs: "Save checkpoint epochs, default is 1."
|
||||
loss_scale: "Loss scale, default is 1024."
|
||||
filter_weight: "Filter weight parameters, default is False."
|
||||
dataset: "Dataset, default is coco."
|
||||
device_id: "Device id, default is 0."
|
||||
file_format: "file format choices [AIR, MINDIR]"
|
||||
file_name: "output file name."
|
||||
export_batch_size: "batch size"
|
|
@ -16,85 +16,70 @@
|
|||
"""Evaluation for retinanet"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
import numpy as np
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.retinanet import retinanet50, resnet50, retinanetInferWithDecoder
|
||||
from src.retinanet import retinanet50, resnet50, retinanetInferWithDecoder
|
||||
from src.dataset import create_retinanet_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
|
||||
from src.config import config
|
||||
from src.coco_eval import metrics
|
||||
from src.box_utils import default_boxes
|
||||
|
||||
def retinanet_eval(dataset_path, ckpt_path):
|
||||
"""retinanet evaluation."""
|
||||
batch_size = 1
|
||||
ds = create_retinanet_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False)
|
||||
backbone = resnet50(config.num_classes)
|
||||
net = retinanet50(backbone, config)
|
||||
net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
|
||||
print("Load Checkpoint!")
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
net.init_parameters_data()
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
net.set_train(False)
|
||||
i = batch_size
|
||||
total = ds.get_dataset_size() * batch_size
|
||||
start = time.time()
|
||||
pred_data = []
|
||||
print("\n========================================\n")
|
||||
print("total images num: ", total)
|
||||
print("Processing, please wait a moment.")
|
||||
for data in ds.create_dict_iterator(output_numpy=True):
|
||||
img_id = data['img_id']
|
||||
img_np = data['image']
|
||||
image_shape = data['image_shape']
|
||||
|
||||
output = net(Tensor(img_np))
|
||||
for batch_idx in range(img_np.shape[0]):
|
||||
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
|
||||
"box_scores": output[1].asnumpy()[batch_idx],
|
||||
"img_id": int(np.squeeze(img_id[batch_idx])),
|
||||
"image_shape": image_shape[batch_idx]})
|
||||
percent = round(i / total * 100., 2)
|
||||
|
||||
print(f' {str(percent)} [{i}/{total}]', end='\r')
|
||||
i += batch_size
|
||||
cost_time = int((time.time() - start) * 1000)
|
||||
print(f' 100% [{total}/{total}] cost {cost_time} ms')
|
||||
mAP = metrics(pred_data)
|
||||
print("\n========================================\n")
|
||||
print(f"mAP: {mAP}")
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='retinanet evaluation')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
||||
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend"),
|
||||
help="run platform, only support Ascend.")
|
||||
args_opt = parser.parse_args()
|
||||
def apply_nms(all_boxes, all_scores, thres, max_boxes):
|
||||
"""Apply NMS to bboxes."""
|
||||
y1 = all_boxes[:, 0]
|
||||
x1 = all_boxes[:, 1]
|
||||
y2 = all_boxes[:, 2]
|
||||
x2 = all_boxes[:, 3]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
|
||||
order = all_scores.argsort()[::-1]
|
||||
keep = []
|
||||
|
||||
prefix = "retinanet_eval.mindrecord"
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
|
||||
if args_opt.dataset == "voc":
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
|
||||
if len(keep) >= max_boxes:
|
||||
break
|
||||
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
inds = np.where(ovr <= thres)[0]
|
||||
|
||||
order = order[inds + 1]
|
||||
return keep
|
||||
|
||||
|
||||
def make_dataset_dir(mindrecord_dir, mindrecord_file, prefix):
|
||||
if config.dataset == "voc":
|
||||
config.coco_root = config.voc_root
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if args_opt.dataset == "coco":
|
||||
if config.dataset == "coco":
|
||||
if os.path.isdir(config.coco_root):
|
||||
print("Create Mindrecord.")
|
||||
data_to_mindrecord_byte_image("coco", False, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("coco_root not exits.")
|
||||
elif args_opt.dataset == "voc":
|
||||
elif config.dataset == "voc":
|
||||
if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root):
|
||||
print("Create Mindrecord.")
|
||||
voc_data_to_mindrecord(mindrecord_dir, False, prefix)
|
||||
|
@ -108,6 +93,163 @@ if __name__ == '__main__':
|
|||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("IMAGE_DIR or ANNO_PATH not exits.")
|
||||
|
||||
print("Start Eval!")
|
||||
retinanet_eval(mindrecord_file, config.checkpoint_path)
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def retinanet_eval():
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
|
||||
prefix = "retinanet_eval.mindrecord"
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
|
||||
make_dataset_dir(mindrecord_dir, mindrecord_file, prefix)
|
||||
|
||||
batch_size = 1
|
||||
ds = create_retinanet_dataset(mindrecord_file, batch_size=batch_size, repeat_num=1, is_training=False)
|
||||
backbone = resnet50(config.num_classes)
|
||||
net = retinanet50(backbone, config)
|
||||
net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
|
||||
print("Load Checkpoint!")
|
||||
param_dict = load_checkpoint(config.checkpoint_path)
|
||||
net.init_parameters_data()
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
net.set_train(False)
|
||||
i = batch_size
|
||||
total = ds.get_dataset_size() * batch_size
|
||||
start = time.time()
|
||||
predictions = []
|
||||
img_ids = []
|
||||
print("\n========================================\n")
|
||||
print("total images num: ", total)
|
||||
print("Processing, please wait a moment.")
|
||||
num_classes = config.num_classes
|
||||
coco_root = config.coco_root
|
||||
data_type = config.val_data_type
|
||||
#Classes need to train or test.
|
||||
val_cls = config.coco_classes
|
||||
val_cls_dict = {}
|
||||
for i, cls in enumerate(val_cls):
|
||||
val_cls_dict[i] = cls
|
||||
anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
|
||||
coco_gt = COCO(anno_json)
|
||||
classs_dict = {}
|
||||
cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
|
||||
for cat in cat_ids:
|
||||
classs_dict[cat["name"]] = cat["id"]
|
||||
|
||||
for data in ds.create_dict_iterator(output_numpy=True):
|
||||
pred_data = []
|
||||
img_id = data['img_id']
|
||||
img_np = data['image']
|
||||
image_shape = data['image_shape']
|
||||
|
||||
output = net(Tensor(img_np))
|
||||
for batch_idx in range(img_np.shape[0]):
|
||||
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
|
||||
"box_scores": output[1].asnumpy()[batch_idx],
|
||||
"img_id": int(np.squeeze(img_id[batch_idx])),
|
||||
"image_shape": image_shape[batch_idx]})
|
||||
i += batch_size
|
||||
for sample in pred_data:
|
||||
pred_boxes = sample['boxes']
|
||||
box_scores = sample['box_scores']
|
||||
img_id = sample['img_id']
|
||||
h, w = sample['image_shape']
|
||||
|
||||
final_boxes = []
|
||||
final_label = []
|
||||
final_score = []
|
||||
img_ids.append(img_id)
|
||||
|
||||
for c in range(1, num_classes):
|
||||
class_box_scores = box_scores[:, c]
|
||||
score_mask = class_box_scores > config.min_score
|
||||
class_box_scores = class_box_scores[score_mask]
|
||||
class_boxes = pred_boxes[score_mask] * [h, w, h, w]
|
||||
|
||||
if score_mask.any():
|
||||
nms_index = apply_nms(class_boxes, class_box_scores, config.nms_thershold, config.max_boxes)
|
||||
class_boxes = class_boxes[nms_index]
|
||||
class_box_scores = class_box_scores[nms_index]
|
||||
final_boxes += class_boxes.tolist()
|
||||
final_score += class_box_scores.tolist()
|
||||
final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores)
|
||||
for loc, label, score in zip(final_boxes, final_label, final_score):
|
||||
res = {}
|
||||
res['image_id'] = img_id
|
||||
res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
|
||||
res['score'] = score
|
||||
res['category_id'] = label
|
||||
predictions.append(res)
|
||||
with open('predictions.json', 'w') as f:
|
||||
json.dump(predictions, f)
|
||||
|
||||
cost_time = int((time.time() - start) * 1000)
|
||||
print(f' 100% [{total}/{total}] cost {cost_time} ms')
|
||||
coco_dt = coco_gt.loadRes('predictions.json')
|
||||
E = COCOeval(coco_gt, coco_dt, iouType='bbox')
|
||||
E.params.imgIds = img_ids
|
||||
E.evaluate()
|
||||
E.accumulate()
|
||||
E.summarize()
|
||||
mAP = E.stats[0]
|
||||
print("\n========================================\n")
|
||||
print(f"mAP: {mAP}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
retinanet_eval()
|
||||
|
|
|
@ -13,26 +13,23 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export for retinanet"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from src.retinanet import retinanet50, resnet50, retinanetInferWithDecoder
|
||||
from src.config import config
|
||||
from src.retinanet import retinanet50, resnet50, retinanetInferWithDecoder
|
||||
from src.model_utils.config import config
|
||||
from src.box_utils import default_boxes
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='retinanet evaluation')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend"),
|
||||
help="run platform, only support Ascend.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--file_name", type=str, default="retinanet", help="output file name.")
|
||||
args_opt = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def model_export():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=config.device_id)
|
||||
|
||||
backbone = resnet50(config.num_classes)
|
||||
net = retinanet50(backbone, config)
|
||||
|
@ -41,6 +38,10 @@ if __name__ == '__main__':
|
|||
net.init_parameters_data()
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
shape = [args_opt.batch_size, 3] + config.img_shape
|
||||
shape = [config.export_batch_size, 3] + config.img_shape
|
||||
input_data = Tensor(np.zeros(shape), mstype.float32)
|
||||
export(net, input_data, file_name=args_opt.file_name, file_format=args_opt.file_format)
|
||||
export(net, input_data, file_name=config.file_name, file_format=config.file_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_export()
|
||||
|
|
|
@ -16,15 +16,15 @@
|
|||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
|
||||
echo "for example: sh run_distribute_train.sh 8 500 0.1 /data/hccl.json /opt/retinanet-500_458.ckpt(optional) 200(optional)"
|
||||
echo "sh scripts/run_distribute_train.sh DEVICE_NUM RANK_TABLE_FILE MINDRECORD_DIR PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
|
||||
echo "for example: sh scripts/run_distribute_train.sh 8 /data/hccl.json /cache/mindrecord_dir/ /opt/retinanet-500_458.ckpt(optional) 200(optional)"
|
||||
echo "It is better to use absolute path."
|
||||
echo "================================================================================================================="
|
||||
|
||||
if [ $# != 4 ] && [ $# != 6 ]
|
||||
if [ $# != 3 ] && [ $# != 5 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] \
|
||||
[RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
|
||||
echo "Usage: sh scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE] \
|
||||
[MINDRECORD_DIR] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -34,11 +34,10 @@ process_cores=$(($core_num/8))
|
|||
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
export RANK_SIZE=$1
|
||||
EPOCH_SIZE=$2
|
||||
LR=$3
|
||||
PRE_TRAINED=$5
|
||||
PRE_TRAINED_EPOCH_SIZE=$6
|
||||
export RANK_TABLE_FILE=$4
|
||||
MINDRECORD_DIR=$3
|
||||
PRE_TRAINED=$4
|
||||
PRE_TRAINED_EPOCH_SIZE=$5
|
||||
export RANK_TABLE_FILE=$2
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
|
@ -48,6 +47,7 @@ do
|
|||
cp ./*.py ./LOG$i
|
||||
cp -r ./src ./LOG$i
|
||||
cp -r ./scripts ./LOG$i
|
||||
cp ./*yaml ./LOG$i
|
||||
start=`expr $i \* $process_cores`
|
||||
end=`expr $start \+ $(($process_cores-1))`
|
||||
cmdopt=$start"-"$end
|
||||
|
@ -55,28 +55,22 @@ do
|
|||
export RANK_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ $# == 4 ]
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
taskset -c $cmdopt python train.py \
|
||||
--workers=$process_cores \
|
||||
--distribute=True \
|
||||
--lr=$LR \
|
||||
--device_num=$RANK_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
|
||||
--mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
if [ $# == 6 ]
|
||||
if [ $# == 5 ]
|
||||
then
|
||||
taskset -c $cmdopt python train.py \
|
||||
--workers=$process_cores \
|
||||
--distribute=True \
|
||||
--lr=$LR \
|
||||
--device_num=$RANK_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--mindrecord_dir=$MINDRECORD_DIR \
|
||||
--pre_trained=$PRE_TRAINED \
|
||||
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
|
||||
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
|
||||
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
cd ../
|
||||
|
|
|
@ -14,18 +14,20 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
if [ $# != 5 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [DATASET] [DEVICE_ID]"
|
||||
echo "Usage: sh scripts/run_eval.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [checkpoint_path] [instances_set]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET=$1
|
||||
DATASET=$2
|
||||
MINDRECORD_DIR=$3
|
||||
CHECKPOINT_PATH=$4
|
||||
INSTANCE_SET=$5
|
||||
echo $DATASET
|
||||
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=$2
|
||||
export DEVICE_ID=$1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
|
@ -40,10 +42,13 @@ fi
|
|||
mkdir ./eval$2
|
||||
cp ./*.py ./eval$2
|
||||
cp -r ./src ./eval$2
|
||||
cp ./*yaml ./eval$2
|
||||
cd ./eval$2 || exit
|
||||
env > env.log
|
||||
echo "start inferring for device $DEVICE_ID"
|
||||
python eval.py \
|
||||
--dataset=$DATASET \
|
||||
--device_id=$2 > log.txt 2>&1 &
|
||||
--checkpoint_path=$CHECKPOINT_PATH \
|
||||
--instances_set=$INSTANCE_SET \
|
||||
--mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
|
||||
cd ..
|
||||
|
|
|
@ -16,57 +16,52 @@
|
|||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
|
||||
echo "for example: sh run_single_train.sh 0 500 0.1 /opt/retinanet-500_458.ckpt(optional) 200(optional)"
|
||||
echo "sh scripts/run_single_train.sh DEVICE_ID MINDRECORD_DIR PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
|
||||
echo "for example: sh scripts/run_single_train.sh 0 /cache/mindrecord_dir/ /opt/retinanet-500_458.ckpt(optional) 200(optional)"
|
||||
echo "It is better to use absolute path."
|
||||
echo "================================================================================================================="
|
||||
|
||||
if [ $# != 3 ] && [ $# != 5 ]
|
||||
if [ $# != 2 ] && [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] \
|
||||
echo "Usage: sh scripts/run_single_train.sh [DEVICE_ID] [MINDRECORD_DIR] \
|
||||
[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Before start single train, first create mindrecord files.
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
# BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
# cd $BASE_PATH/../ || exit
|
||||
# python train.py --only_create_dataset=True
|
||||
|
||||
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
export DEVICE_ID=$1
|
||||
EPOCH_SIZE=$2
|
||||
LR=$3
|
||||
PRE_TRAINED=$4
|
||||
PRE_TRAINED_EPOCH_SIZE=$5
|
||||
MINDRECORD_DIR=$2
|
||||
PRE_TRAINED=$3
|
||||
PRE_TRAINED_EPOCH_SIZE=$4
|
||||
|
||||
rm -rf LOG$1
|
||||
mkdir ./LOG$1
|
||||
cp ./*.py ./LOG$1
|
||||
cp -r ./src ./LOG$1
|
||||
cp ./*yaml ./LOG$1
|
||||
cd ./LOG$1 || exit
|
||||
echo "start training for device $1"
|
||||
env > env.log
|
||||
if [ $# == 3 ]
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
python train.py \
|
||||
--distribute=False \
|
||||
--lr=$LR \
|
||||
--device_num=1 \
|
||||
--device_id=$DEVICE_ID \
|
||||
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
|
||||
--mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
if [ $# == 5 ]
|
||||
if [ $# == 4 ]
|
||||
then
|
||||
python train,py \
|
||||
--distribute=False \
|
||||
--lr=$LR \
|
||||
--device_num=1 \
|
||||
--device_id=$DEVICE_ID \
|
||||
--mindrecord_dir=$MINDRECORD_DIR \
|
||||
--pre_trained=$PRE_TRAINED \
|
||||
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
|
||||
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
|
||||
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
cd ../
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
import math
|
||||
import itertools as it
|
||||
import numpy as np
|
||||
from .config import config
|
||||
from src.model_utils.config import config
|
||||
|
||||
|
||||
class GeneratDefaultBoxes():
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
from .config import config
|
||||
from src.model_utils.config import config
|
||||
|
||||
|
||||
def apply_nms(all_boxes, all_scores, thres, max_boxes):
|
||||
|
|
|
@ -1,86 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
|
||||
"""Config parameters for retinanet models."""
|
||||
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"img_shape": [600, 600],
|
||||
"num_retinanet_boxes": 67995,
|
||||
"match_thershold": 0.5,
|
||||
"nms_thershold": 0.6,
|
||||
"min_score": 0.1,
|
||||
"max_boxes": 100,
|
||||
|
||||
# learing rate settings
|
||||
"global_step": 0,
|
||||
"lr_init": 1e-6,
|
||||
"lr_end_rate": 5e-3,
|
||||
"warmup_epochs1": 2,
|
||||
"warmup_epochs2": 5,
|
||||
"warmup_epochs3": 23,
|
||||
"warmup_epochs4": 60,
|
||||
"warmup_epochs5": 160,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1.5e-4,
|
||||
|
||||
# network
|
||||
"num_default": [9, 9, 9, 9, 9],
|
||||
"extras_out_channels": [256, 256, 256, 256, 256],
|
||||
"feature_size": [75, 38, 19, 10, 5],
|
||||
"aspect_ratios": [(0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0)],
|
||||
"steps": (8, 16, 32, 64, 128),
|
||||
"anchor_size": (32, 64, 128, 256, 512),
|
||||
"prior_scaling": (0.1, 0.2),
|
||||
"gamma": 2.0,
|
||||
"alpha": 0.75,
|
||||
|
||||
# `mindrecord_dir` and `coco_root` are better to use absolute path.
|
||||
"mindrecord_dir": "/data/hitwh/retinanet/MindRecord_COCO",
|
||||
"coco_root": "/data/dataset/coco2017",
|
||||
"train_data_type": "train2017",
|
||||
"val_data_type": "val2017",
|
||||
"instances_set": "/data/dataset/coco2017/annotations/instances_{}.json",
|
||||
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
||||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
||||
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||
'teddy bear', 'hair drier', 'toothbrush'),
|
||||
"num_classes": 81,
|
||||
# The annotation.json position of voc validation dataset.
|
||||
"voc_root": "",
|
||||
# voc original dataset.
|
||||
"voc_dir": "",
|
||||
# if coco or voc used, `image_dir` and `anno_path` are useless.
|
||||
"image_dir": "",
|
||||
"anno_path": "",
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 1,
|
||||
"save_checkpoint_path": "./model",
|
||||
"finish_epoch": 0,
|
||||
"checkpoint_path": "/home/hitwh1/1.0/ckpt_0/retinanet-500_458_59.ckpt"
|
||||
})
|
|
@ -22,11 +22,10 @@ import json
|
|||
import xml.etree.ElementTree as et
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from .config import config
|
||||
from src.model_utils.config import config
|
||||
from .box_utils import jaccard_numpy, retinanet_bboxes_encode
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
|
||||
_config_path = '../../default_config.yaml'
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml
|
||||
|
||||
Args:
|
||||
parser: Parent parser
|
||||
cfg: Base configuration
|
||||
helper: Helper description
|
||||
cfg_path: Path to the default yaml config
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError('At most 3 docs (config description for help, choices) are supported in config yaml')
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError('Failed to parse yaml')
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments
|
||||
|
||||
Args:
|
||||
args: command line arguments
|
||||
cfg: Base configuration
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='default name', add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path),
|
||||
help='Config file path')
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return 'Local Job'
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# you may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0#
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS
|
||||
# WITHOUT WARRANT IES OR CONITTONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ====================================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
|
||||
_global_syn_count = 0
|
||||
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local
|
||||
Uploca data from local directory to remote obs in contrast
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_syn_count
|
||||
sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
|
||||
_global_syn_count += 1
|
||||
|
||||
# Each server contains 8 devices as most
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print('from path: ', from_path)
|
||||
print('to path: ', to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print('===finished data synchronization===')
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print('===save flag===')
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
print('Finish sync data from {} to {}'.format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print('Dataset downloaded: ', os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
if not os.path.exists(config.load_path):
|
||||
# os.makedirs(config.load_path)
|
||||
print('=' * 20 + 'makedirs')
|
||||
if os.path.isdir(config.load_path):
|
||||
print('=' * 20 + 'makedirs success')
|
||||
else:
|
||||
print('=' * 20 + 'makedirs fail')
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print('Preload downloaded: ', os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print('Workspace downloaded: ', os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print('Start to copy output directory')
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -16,24 +16,28 @@
|
|||
"""Train retinanet and get checkpoint files."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import time
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor, Callback
|
||||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from src.retinanet import retinanetWithLossCell, TrainingWrapper, retinanet50, resnet50
|
||||
from src.config import config
|
||||
from src.retinanet import retinanetWithLossCell, TrainingWrapper, retinanet50, resnet50
|
||||
from src.dataset import create_retinanet_dataset
|
||||
from src.lr_schedule import get_lr
|
||||
from src.init_params import init_net_param, filter_checkpoint_parameter
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
class Monitor(Callback):
|
||||
"""
|
||||
Monitor loss and time.
|
||||
|
@ -52,81 +56,118 @@ class Monitor(Callback):
|
|||
super(Monitor, self).__init__()
|
||||
self.lr_init = lr_init
|
||||
self.lr_init_len = len(lr_init)
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
print("lr:[{:8.6f}]".format(self.lr_init[cb_params.cur_step_num-1]), flush=True)
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="retinanet training")
|
||||
|
||||
parser.add_argument("--distribute", type=ast.literal_eval, default=False,
|
||||
help="Run distribute, default is False.")
|
||||
parser.add_argument("--workers", type=int, default=24, help="Num parallel workers.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
||||
parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.")
|
||||
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
|
||||
parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
|
||||
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
|
||||
parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
|
||||
parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.")
|
||||
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
|
||||
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
|
||||
help="Filter weight parameters, default is False.")
|
||||
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend"),
|
||||
help="run platform, only support Ascend.")
|
||||
args_opt = parser.parse_args()
|
||||
config.lr_init = ast.literal_eval(config.lr_init)
|
||||
config.lr_end_rate = ast.literal_eval(config.lr_end_rate)
|
||||
|
||||
if args_opt.run_platform == "Ascend":
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if args_opt.distribute:
|
||||
if config.distribute:
|
||||
if os.getenv("DEVICE_ID", "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
|
||||
context.set_context(device_id=get_device_id())
|
||||
init()
|
||||
device_num = args_opt.device_num
|
||||
rank = get_rank()
|
||||
device_num = get_device_num()
|
||||
rank = get_rank_id()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=device_num)
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
mindrecord_file = os.path.join(config.mindrecord_dir, "retinanet.mindrecord0")
|
||||
|
||||
loss_scale = float(args_opt.loss_scale)
|
||||
loss_scale = float(config.loss_scale)
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as retinanet.mindrecord0.
|
||||
dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1,
|
||||
num_parallel_workers=args_opt.workers,
|
||||
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
|
||||
num_parallel_workers=config.workers,
|
||||
batch_size=config.batch_size, device_num=device_num, rank=rank)
|
||||
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("Create dataset done!")
|
||||
|
||||
|
||||
backbone = resnet50(config.num_classes)
|
||||
retinanet = retinanet50(backbone, config)
|
||||
net = retinanetWithLossCell(retinanet, config)
|
||||
init_net_param(net)
|
||||
|
||||
if args_opt.pre_trained:
|
||||
if args_opt.pre_trained_epoch_size <= 0:
|
||||
if config.pre_trained:
|
||||
if config.pre_trained_epoch_size <= 0:
|
||||
raise KeyError("pre_trained_epoch_size must be greater than 0.")
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
if args_opt.filter_weight:
|
||||
param_dict = load_checkpoint(config.pre_trained)
|
||||
if config.filter_weight:
|
||||
filter_checkpoint_parameter(param_dict)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
lr = Tensor(get_lr(global_step=config.global_step,
|
||||
lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
|
||||
lr_init=config.lr_init, lr_end=config.lr_end_rate * config.lr, lr_max=config.lr,
|
||||
warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2,
|
||||
warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4,
|
||||
warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size,
|
||||
warmup_epochs5=config.warmup_epochs5, total_epochs=config.epoch_size,
|
||||
steps_per_epoch=dataset_size))
|
||||
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
|
||||
config.momentum, config.weight_decay, loss_scale)
|
||||
|
@ -135,16 +176,17 @@ def main():
|
|||
print("Start train retinanet, the first epoch will be slower because of the graph compilation.")
|
||||
cb = [TimeMonitor(), LossMonitor()]
|
||||
cb += [Monitor(lr_init=lr.asnumpy())]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs,
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * config.save_checkpoint_epochs,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck)
|
||||
if args_opt.distribute:
|
||||
if config.distribute:
|
||||
if rank == 0:
|
||||
cb += [ckpt_cb]
|
||||
model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
else:
|
||||
cb += [ckpt_cb]
|
||||
model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue