forked from mindspore-Ecosystem/mindspore
!18354 simple_pose can be used on ModelArts.
Merge pull request !18354 from 郑彬/master
This commit is contained in: commit ed584cafcf
@@ -63,8 +63,7 @@ To run the python scripts in the repository, you need to prepare the environment

- Prepare hardware environment with Ascend.
- Python and dependencies
    - python 3.7
    - mindspore 1.0.1
    - easydict 1.9
    - mindspore 1.2.0
    - opencv-python 4.3.0.36
    - pycocotools 2.0
- For more information, please check the resources below:
@@ -97,23 +96,77 @@ Before you start your training process, you need to obtain mindspore imagenet pr

## [Running](#contents)

To train the model, run the shell script `scripts/train_standalone.sh` with the format below:

- running on local

```shell
sh scripts/train_standalone.sh [device_id] [ckpt_path_to_save]
```

To validate the model, change the settings in `src/config.py` to the path of the model you want to validate. For example:

```shell
sh scripts/train_standalone.sh [CKPT_SAVE_DIR] [DEVICE_ID] [BATCH_SIZE]
```

```python
config.TEST.MODEL_FILE='results/xxxx.ckpt'
```

To validate the model, change the settings in `default_config.yaml` to the path of the model you want to validate, or set it on the terminal. For example:

Then, run the shell script `scripts/eval.sh` with the format below:

```python
TEST:
    ...
    MODEL_FILE: './{path}/xxxx.ckpt'
```

```shell
sh scripts/eval.sh [device_id]
```

```shell
sh scripts/eval.sh [TEST_MODEL_FILE] [COCO_BBOX_FILE] [DEVICE_ID]
```

- running on ModelArts

If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/), then you can start training as follows:

- Training with 8 cards on ModelArts

```python
# (1) Upload the code folder to S3 bucket.
# (2) Click "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/simple_pose" on the website UI interface.
# (4) Set the startup file to "/{path}/simple_pose/train.py" on the website UI interface.
# (5) Perform a or b.
#     a. setting parameters in /{path}/simple_pose/default_config.yaml.
#        1. Set "run_distributed: True"
#        2. Set "enable_modelarts: True"
#        3. Set "batch_size: 64" (optional)
#     b. adding on the website UI interface.
#        1. Add "run_distributed=True"
#        2. Add "enable_modelarts=True"
#        3. Add "batch_size=64" (optional)
# (6) Upload the dataset to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path".
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of 8 cards.
# (10) Create your job.
```

- evaluating with single card on ModelArts

```python
# (1) Upload the code folder to S3 bucket.
# (2) Click "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/simple_pose" on the website UI interface.
# (4) Set the startup file to "/{path}/simple_pose/eval.py" on the website UI interface.
# (5) Perform a or b.
#     a. setting parameters in /{path}/simple_pose/default_config.yaml.
#        1. Set "enable_modelarts: True"
#        2. Set "eval_model_file: ./{path}/*.ckpt" ('eval_model_file' indicates the path of the weight file to be evaluated relative to the file `eval.py`, and the weight file must be included in the code directory.)
#        3. Set "coco_bbox_file: ./{path}/COCO_val2017_detections_AP_H_56_person.json" (the same as 'eval_model_file')
#     b. adding on the website UI interface.
#        1. Add "enable_modelarts=True"
#        2. Add "eval_model_file=./{path}/*.ckpt" ('eval_model_file' indicates the path of the weight file to be evaluated relative to the file `eval.py`, and the weight file must be included in the code directory.)
#        3. Add "coco_bbox_file=./{path}/COCO_val2017_detections_AP_H_56_person.json" (the same as 'eval_model_file')
# (6) Upload the dataset to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" (this path should contain only the data or a zip package).
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of a single card.
# (10) Create your job.
```

# [Script Description](#contents)
@@ -122,80 +175,95 @@ sh scripts/eval.sh [device_id]

The structure of the files in this repository is shown below.

```text
└─ mindspore-simpleposenet
 ├─ scripts
 │ ├─ eval.sh // launch ascend standalone evaluation
 │ ├─ train_distributed.sh // launch ascend distributed training
 │ └─ train_standalone.sh // launch ascend standalone training
 ├─ src
 │ ├─ utils
 │ │ ├─ transform.py // utils about image transformation
 │ │ └─ nms.py // utils about nms
 │ ├─ evaluate
 │ │ └─ coco_eval.py // evaluate result by coco
 │ ├─ config.py // network and running config
 │ ├─ dataset.py // dataset processor and provider
 │ ├─ model.py // SimplePoseNet implementation
 │ ├─ network_define.py // define loss
 │ └─ predict.py // predict keypoints from heatmaps
 ├─ eval.py // evaluation script
 ├─ param_convert.py // model parameters conversion script
 ├─ train.py // training script
 └─ README.md // descriptions about this repository
└─ simple_pose
 ├─ scripts
 │ ├─ eval.sh // launch ascend standalone evaluation
 │ ├─ train_distributed.sh // launch ascend distributed training
 │ └─ train_standalone.sh // launch ascend standalone training
 ├─ src
 │ ├─ utils
 │ │ ├─ transform.py // utils about image transformation
 │ │ └─ nms.py // utils about nms
 │ ├─ evaluate
 │ │ └─ coco_eval.py // evaluate result by coco
 │ ├─ model_utils
 │ │ ├── config.py // parsing parameter configuration file of "*.yaml"
 │ │ ├── device_adapter.py // local or ModelArts training
 │ │ ├── local_adapter.py // get related environment variables in local training
 │ │ └── moxing_adapter.py // get related environment variables in ModelArts training
 │ ├─ dataset.py // dataset processor and provider
 │ ├─ model.py // SimplePoseNet implementation
 │ ├─ network_define.py // define loss
 │ └─ predict.py // predict keypoints from heatmaps
 ├─ default_config.yaml // parameter configuration
 ├─ eval.py // evaluation script
 ├─ train.py // training script
 └─ README.md // descriptions about this repository
```

## [Script Parameters](#contents)

Configurations for both training and evaluation are set in `src/config.py`. All the settings are shown below.
Configurations for both training and evaluation are set in `default_config.yaml`. All the settings are shown below.

- config for SimplePoseNet on COCO2017 dataset:

```python
# pose_resnet related params
POSE_RESNET.HEATMAP_SIZE = [48, 64]  # heatmap size
POSE_RESNET.SIGMA = 2  # Gaussian hyperparameter in heatmap generation
POSE_RESNET.FINAL_CONV_KERNEL = 1  # final convolution kernel size
POSE_RESNET.DECONV_WITH_BIAS = False  # deconvolution bias
POSE_RESNET.NUM_DECONV_LAYERS = 3  # the number of deconvolution layers
POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]  # the filter size of deconvolution layers
POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]  # kernel size of deconvolution layers
POSE_RESNET.NUM_LAYERS = 50  # number of layers (for resnet)
# common params for NETWORK
config.MODEL.NAME = 'pose_resnet'  # model name
config.MODEL.INIT_WEIGHTS = True  # init model weights by resnet
config.MODEL.PRETRAINED = './models/resnet50.ckpt'  # pretrained model
config.MODEL.NUM_JOINTS = 17  # the number of keypoints
config.MODEL.IMAGE_SIZE = [192, 256]  # image size
# dataset
config.DATASET.ROOT = '/data/coco2017/'  # coco2017 dataset root
config.DATASET.TEST_SET = 'val2017'  # folder name of test set
config.DATASET.TRAIN_SET = 'train2017'  # folder name of train set
# data augmentation
config.DATASET.FLIP = True  # random flip
config.DATASET.ROT_FACTOR = 40  # random rotation
config.DATASET.SCALE_FACTOR = 0.3  # random scale
# for train
config.TRAIN.BATCH_SIZE = 64  # batch size
config.TRAIN.BEGIN_EPOCH = 0  # begin epoch
config.TRAIN.END_EPOCH = 140  # end epoch
config.TRAIN.LR = 0.001  # initial learning rate
config.TRAIN.LR_FACTOR = 0.1  # learning rate reduce factor
config.TRAIN.LR_STEP = [90,120]  # step to reduce lr
# test
config.TEST.BATCH_SIZE = 32  # batch size
config.TEST.FLIP_TEST = True  # flip test
config.TEST.POST_PROCESS = True  # post process
config.TEST.SHIFT_HEATMAP = True  # shift heatmap
config.TEST.USE_GT_BBOX = False  # use groundtruth bbox
config.TEST.MODEL_FILE = ''  # model file to test
# detect bbox file
config.TEST.COCO_BBOX_FILE = 'experiments/COCO_val2017_detections_AP_H_56_person.json'
# nms
config.TEST.OKS_THRE = 0.9  # oks threshold
config.TEST.IN_VIS_THRE = 0.2  # visible threshold
config.TEST.BBOX_THRE = 1.0  # bbox threshold
config.TEST.IMAGE_THRE = 0.0  # image threshold
config.TEST.NMS_THRE = 1.0  # nms threshold
```

```text
# These parameters can be modified at the terminal
ckpt_save_dir: 'checkpoints'  # the folder to save the '*.ckpt' file
batch_size: 128  # TRAIN.BATCH_SIZE
run_distribute: False  # training by several devices: "true"(training by more than 1 device) | "false", default is "false"
eval_model_file: ''  # TEST.MODEL_FILE
coco_bbox_file: ''  # TEST.COCO_BBOX_FILE
#pose_resnet-related
POSE_RESNET:
    NUM_LAYERS: 50  # number of layers (for resnet)
    DECONV_WITH_BIAS: False  # deconvolution bias
    NUM_DECONV_LAYERS: 3  # the number of deconvolution layers
    NUM_DECONV_FILTERS: [256, 256, 256]  # the filter size of deconvolution layers
    NUM_DECONV_KERNELS: [4, 4, 4]  # kernel size of deconvolution layers
    FINAL_CONV_KERNEL: 1  # final convolution kernel size
    TARGET_TYPE: 'gaussian'
    HEATMAP_SIZE: [48, 64]  # heatmap size
    SIGMA: 2  # Gaussian hyperparameter in heatmap generation
#network-related
MODEL:
    NAME: 'pose_resnet'  # model name
    INIT_WEIGHTS: True  # init model weights by resnet
    PRETRAINED: './resnet50.ckpt'  # pretrained model
    NUM_JOINTS: 17  # the number of keypoints
    IMAGE_SIZE: [192, 256]  # image size
#dataset-related
DATASET:
    ROOT: '/data/coco2017/'  # coco2017 dataset root
    TEST_SET: 'val2017'  # folder name of test set
    TRAIN_SET: 'train2017'  # folder name of train set
    FLIP: True  # random flip
    ROT_FACTOR: 40  # random rotation
    SCALE_FACTOR: 0.3  # random scale
#train-related
TRAIN:
    BATCH_SIZE: 64  # batch size
    BEGIN_EPOCH: 0  # begin epoch
    END_EPOCH: 140  # end epoch
    LR: 0.001  # initial learning rate
    LR_FACTOR: 0.1  # learning rate reduce factor
    LR_STEP: [90, 120]  # step to reduce lr
#eval-related
TEST:
    BATCH_SIZE: 32  # batch size
    FLIP_TEST: True  # flip test
    POST_PROCESS: True  # post process
    SHIFT_HEATMAP: True  # shift heatmap
    USE_GT_BBOX: False  # use groundtruth bbox
    MODEL_FILE: ''  # model file to test
    DATALOADER_WORKERS: 8
    COCO_BBOX_FILE: 'experiments/COCO_val2017_detections_AP_H_56_person.json'
    #nms-related
    OKS_THRE: 0.9  # oks threshold
    IN_VIS_THRE: 0.2  # visible threshold
    BBOX_THRE: 1.0  # bbox threshold
    IMAGE_THRE: 0.0  # image threshold
    NMS_THRE: 1.0  # nms threshold
```
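As a rough illustration of how the flat terminal-level keys relate to the nested sections, here is a sketch only, assuming the `config` object built by `src/model_utils/config.py` (shown later in this commit):

```python
# Sketch: values come from default_config.yaml unless overridden on the command line.
from src.model_utils.config import config

print(config.TRAIN.BATCH_SIZE)     # 64 by default; passing --batch_size=128 overrides it, because
                                   # extra_operations() copies config.batch_size into TRAIN.BATCH_SIZE
print(config.TEST.MODEL_FILE)      # '' by default; set via --eval_model_file=/path/to/xxxx.ckpt
print(config.TEST.COCO_BBOX_FILE)  # overridden by --coco_bbox_file=... when given
```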
## [Training Process](#contents)

@@ -207,13 +275,13 @@ config.TEST.NMS_THRE = 1.0 # nms threshold

Run `scripts/train_standalone.sh` to train the model standalone. The usage of the script is:

```shell
sh scripts/train_standalone.sh [device_id] [ckpt_path_to_save]
sh scripts/train_standalone.sh [CKPT_SAVE_DIR] [DEVICE_ID] [BATCH_SIZE]
```

For example, you can run the shell command below to launch the training procedure.

```shell
sh scripts/train_standalone.sh 0 results/standalone/
sh scripts/train_standalone.sh results/standalone/ 0 128
```

The script will run training in the background; you can view the results through the file `train_log[X].txt`, for example:

@@ -232,7 +300,7 @@ Epoch time: 456265.617, per step time: 389.971
...
```

The model checkpoint will be saved into `[ckpt_path_to_save]`.
The model checkpoint will be saved into `[CKPT_SAVE_DIR]`.

### [Distributed Training](#contents)

@@ -241,7 +309,7 @@ The model checkpoint will be saved into `[ckpt_path_to_save]`.

Run `scripts/train_distributed.sh` to train the model in distributed mode. The usage of the script is:

```shell
sh scripts/train_distributed.sh [rank_table] [ckpt_path_to_save] [device_number]
sh scripts/train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [CKPT_SAVE_DIR] [RANK_SIZE]
```

For example, you can run the shell command below to launch the distributed training procedure.

@@ -266,28 +334,22 @@ Epoch time: 164792.001, per step time: 281.696
...
```

The model checkpoint will be saved into `[ckpt_path_to_save]`.
The model checkpoint will be saved into `[CKPT_SAVE_DIR]`.

## [Evaluation Process](#contents)

### Running on Ascend

Change the settings in `src/config.py` to the path of the model you want to validate. For example:

```python
config.TEST.MODEL_FILE='results/xxxx.ckpt'
```

Then, run `scripts/eval.sh` to evaluate the model with one Ascend processor. The usage of the script is:
Run `scripts/eval.sh` to evaluate the model with one Ascend processor. The usage of the script is:

```shell
sh scripts/eval.sh [device_id]
sh scripts/eval.sh [TEST_MODEL_FILE] [COCO_BBOX_FILE] [DEVICE_ID]
```

For example, you can run the shell command below to launch the validation procedure.

```shell
sh scripts/eval.sh 0
sh scripts/eval.sh results/distributed/sim-140_1170.ckpt
```

The above shell command will run the validation procedure in the background. You can view the results through the file `eval_log[X].txt`. The results will look like the following:
@@ -0,0 +1,87 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False

modelarts_dataset_unzip_name: ''
# ==============================================================================
# Parameters that can be modified at the terminal
ckpt_save_dir: 'checkpoints'
batch_size: 64
run_distribute: False
eval_model_file: ''
coco_bbox_file: ''
#pose_resnet-related
POSE_RESNET:
    NUM_LAYERS: 50
    DECONV_WITH_BIAS: False
    NUM_DECONV_LAYERS: 3
    NUM_DECONV_FILTERS: [256, 256, 256]
    NUM_DECONV_KERNELS: [4, 4, 4]
    FINAL_CONV_KERNEL: 1
    TARGET_TYPE: 'gaussian'
    HEATMAP_SIZE: [48, 64]
    SIGMA: 2
#network-related
MODEL:
    NAME: 'pose_resnet'
    INIT_WEIGHTS: True
    PRETRAINED: './resnet50.ckpt'
    NUM_JOINTS: 17
    IMAGE_SIZE: [192, 256] # width * height, ex: 192 * 256
#dataset-related
DATASET:
    ROOT: '/data/coco2017/'
    TEST_SET: 'val2017'
    TRAIN_SET: 'train2017'
    FLIP: True
    ROT_FACTOR: 40
    SCALE_FACTOR: 0.3
#train-related
TRAIN:
    BATCH_SIZE: 64
    BEGIN_EPOCH: 0
    END_EPOCH: 140
    LR: 0.001
    LR_FACTOR: 0.1
    LR_STEP: [90, 120]
#eval-related
TEST:
    BATCH_SIZE: 32
    FLIP_TEST: True
    POST_PROCESS: True
    SHIFT_HEATMAP: True
    USE_GT_BBOX: False
    MODEL_FILE: ''
    DATALOADER_WORKERS: 8
    COCO_BBOX_FILE: 'experiments/COCO_val2017_detections_AP_H_56_person.json'
    #nms-related
    OKS_THRE: 0.9
    IN_VIS_THRE: 0.2
    BBOX_THRE: 1.0
    IMAGE_THRE: 0.0
    NMS_THRE: 1.0

---

# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
enable_profiling: 'Whether enable profiling while training, default: False'
# Parameters that can be modified at the terminal
ckpt_save_dir: "ckpt path to save"
batch_size: "training batch size"
run_distribute: "Run distribute, default is false."
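Note that `default_config.yaml` is really two YAML documents separated by `---`: the values themselves and a help string for each key. A minimal sketch of reading both documents (mirroring `parse_yaml` in `src/model_utils/config.py`, shown later in this commit):

```python
# Sketch: load both documents from default_config.yaml.
import yaml

with open("default_config.yaml", "r") as fin:
    cfg, cfg_helper = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))

print(cfg["TRAIN"]["BATCH_SIZE"])   # 64
print(cfg_helper["batch_size"])     # "training batch size"
```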
@@ -12,77 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import time
import numpy as np

from mindspore import Tensor, float32, context
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config
from src.dataset import flip_pairs, keypoint_dataset
from src.evaluate.coco_eval import evaluate
from src.model import get_pose_net
from src.utils.transform import flip_back
from src.predict import get_final_preds


def parse_args():
    parser = argparse.ArgumentParser(description='Train keypoints network')
    parser.add_argument("--train_url", type=str, default="", help="")
    parser.add_argument("--data_url", type=str, default="", help="data")
    # output
    parser.add_argument('--output-url',
                        help='output dir',
                        type=str)
    # training
    parser.add_argument('--workers',
                        help='num of dataloader workers',
                        default=8,
                        type=int)
    parser.add_argument('--model-file',
                        help='model state file',
                        type=str)
    parser.add_argument('--use-detect-bbox',
                        help='use detect bbox',
                        action='store_true')
    parser.add_argument('--flip-test',
                        help='use flip test',
                        default=True,
                        action='store_true')
    parser.add_argument('--post-process',
                        help='use post process',
                        action='store_true')
    parser.add_argument('--shift-heatmap',
                        help='shift heatmap',
                        action='store_true')
    parser.add_argument('--coco-bbox-file',
                        help='coco detection bbox file',
                        type=str)

    args = parser.parse_args()

    return args


def reset_config(cfg, args):
    if args.use_detect_bbox:
        cfg.TEST.USE_GT_BBOX = not args.use_detect_bbox
    if args.flip_test:
        cfg.TEST.FLIP_TEST = args.flip_test
        print('use flip test:', cfg.TEST.FLIP_TEST)
    if args.post_process:
        cfg.TEST.POST_PROCESS = args.post_process
    if args.shift_heatmap:
        cfg.TEST.SHIFT_HEATMAP = args.shift_heatmap
    if args.model_file:
        cfg.TEST.MODEL_FILE = args.model_file
    if args.coco_bbox_file:
        cfg.TEST.COCO_BBOX_FILE = args.coco_bbox_file

from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id

def validate(cfg, val_dataset, model, output_dir):
    # switch to evaluate mode

@@ -141,38 +86,43 @@ def validate(cfg, val_dataset, model, output_dir):
    print("AP:", perf_indicator)
    return perf_indicator

def modelarts_pre_process():
    '''modelarts pre process function.'''
    config.TEST.MODEL_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), config.TEST.MODEL_FILE)
    config.DATASET.ROOT = config.data_path


@moxing_wrapper(pre_process=modelarts_pre_process)
def main():
    # init seed
    set_seed(1)

    # set context
    device_id = int(os.getenv('DEVICE_ID'))
    device_id = get_device_id()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend", save_graphs=False, device_id=device_id)

    args = parse_args()
    # update config
    reset_config(config, args)

    # init model
    model = get_pose_net(config, is_train=False)

    # load parameters
    ckpt_name = config.TEST.MODEL_FILE
    print('loading model ckpt from {}'.format(ckpt_name))
    load_param_into_net(model, load_checkpoint(ckpt_name))
    ckpt_file = config.TEST.MODEL_FILE
    print('loading model ckpt from {}'.format(ckpt_file))
    load_param_into_net(model, load_checkpoint(ckpt_file))

    # Data loading code
    valid_dataset, _ = keypoint_dataset(
        config,
        bbox_file=config.TEST.COCO_BBOX_FILE,
        train_mode=False,
        num_parallel_workers=args.workers,
        num_parallel_workers=config.TEST.DATALOADER_WORKERS,
    )

    # evaluate on validation set
    validate(config, valid_dataset, model, ckpt_name.split('.')[0])
    output_dir = ckpt_file.split('.')[0]
    if config.enable_modelarts:
        output_dir = config.output_path
    validate(config, valid_dataset, model, output_dir)


if __name__ == '__main__':
@@ -13,6 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
echo "$1 $2 $3"

python eval.py > eval_log$1.txt 2>&1 &
if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ]
then
    echo "Usage: bash eval.sh [TEST_MODEL_FILE] [COCO_BBOX_FILE] [DEVICE_ID]"
    exit 1
fi

DEVICE_ID=0

if [ $# -ge 3 ]
then
    expr $3 + 6 &>/dev/null
    if [ $? != 0 ]
    then
        echo "error: DEVICE_ID=$3 is not an integer"
        exit 1
    fi
    DEVICE_ID=$3
fi

export DEVICE_ID=$DEVICE_ID

rm -rf ./eval
mkdir ./eval
echo "start evaluating for device $DEVICE_ID"
cd ./eval || exit
env >env.log
cd ../
python eval.py \
    --eval_model_file=$1 --coco_bbox_file=$2 \
    > ./eval/eval_log.txt 2>&1 &
echo "python eval.py --eval_model_file=$1 --coco_bbox_file=$2 > ./eval/eval_log.txt 2>&1 &"
@@ -13,12 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Usage: sh train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [SAVE_CKPT_PATH] [RANK_SIZE]
# Usage: sh train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [CKPT_SAVE_DIR] [RANK_SIZE]
echo "$1 $2 $3"

if [ $# != 3 ]; then
    echo "Usage: sh train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [CKPT_SAVE_DIR] [RANK_SIZE]"
    exit 1
fi
if [ ! -f $1 ]; then
    echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file"
    exit 1
fi
export RANK_TABLE_FILE=$1
echo "RANK_TABLE_FILE=$RANK_TABLE_FILE"
export RANK_SIZE=$3
SAVE_PATH=$2
CKPT_SAVE_DIR=$2

for((i=0;i<RANK_SIZE;i++))
do

@@ -33,11 +42,9 @@ do
    env > env.log
    cd ../
    python train.py \
        --run-distribute \
        --ckpt-path=$SAVE_PATH > train_parallel$i/log.txt 2>&1 &

        --ckpt_save_dir=$CKPT_SAVE_DIR \
        --run_distribute=True > ./train_parallel$i/log.txt 2>&1 &
    echo "python train.py \
        --run-distribute \
        --ckpt-path=$SAVE_PATH > train_parallel$i/log.txt 2>&1 &"

        --ckpt_save_dir=$CKPT_SAVE_DIR \
        --run_distribute=True > ./train_parallel$i/log.txt 2>&1 &"
done
@@ -13,10 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Usage: train_standalone.sh [DEVICE_ID] [SAVE_CKPT_PATH]
export DEVICE_ID=$1
# Usage: train_standalone.sh [CKPT_SAVE_DIR] [DEVICE_ID]
echo "$1 $2 $3"

if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ]
then
    echo "Usage: bash train_standalone.sh [CKPT_SAVE_DIR] [DEVICE_ID] [BATCH_SIZE]"
    exit 1
fi

DEVICE_ID=0

if [ $# -ge 2 ]
then
    expr $2 + 6 &>/dev/null
    if [ $? != 0 ]
    then
        echo "error: DEVICE_ID=$2 is not an integer"
        exit 1
    fi
    DEVICE_ID=$2
fi

BATCH_SIZE=128

if [ $# -ge 3 ]
then
    expr $3 + 6 &>/dev/null
    if [ $? != 0 ]
    then
        echo "error: BATCH_SIZE=$3 is not an integer"
        exit 1
    fi
    BATCH_SIZE=$3
fi

export DEVICE_ID=$DEVICE_ID

rm -rf ./train_single
mkdir ./train_single
echo "start training for rank 0, device $DEVICE_ID"
cd ./train_single || exit
env >env.log
cd ../
python train.py \
    --ckpt-path=$2 --batch-size=128 \
    > train_log$1.txt 2>&1 &
echo " python train.py --ckpt-path=$2 --batch-size=128 > train_log$1.txt 2>&1 &"
    --ckpt_save_dir=$1 --batch_size=$BATCH_SIZE \
    > ./train_single/train_log.txt 2>&1 &
echo " python train.py --ckpt_save_dir=$1 --batch_size=$BATCH_SIZE > ./train_single/train_log.txt 2>&1 &"
@@ -1,77 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from easydict import EasyDict as edict

config = edict()

# pose_resnet related params
POSE_RESNET = edict()
POSE_RESNET.NUM_LAYERS = 50
POSE_RESNET.DECONV_WITH_BIAS = False
POSE_RESNET.NUM_DECONV_LAYERS = 3
POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
POSE_RESNET.FINAL_CONV_KERNEL = 1
POSE_RESNET.TARGET_TYPE = 'gaussian'
POSE_RESNET.HEATMAP_SIZE = [48, 64]  # width * height, ex: 24 * 32
POSE_RESNET.SIGMA = 2

MODEL_EXTRAS = {
    'pose_resnet': POSE_RESNET,
}

# common params for NETWORK
config.MODEL = edict()
config.MODEL.NAME = 'pose_resnet'
config.MODEL.INIT_WEIGHTS = True
config.MODEL.PRETRAINED = './models/resnet50.ckpt'
config.MODEL.NUM_JOINTS = 17
config.MODEL.IMAGE_SIZE = [192, 256]  # width * height, ex: 192 * 256
config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]

# dataset
config.DATASET = edict()
config.DATASET.ROOT = '/data/coco2017/'
config.DATASET.TEST_SET = 'val2017'
config.DATASET.TRAIN_SET = 'train2017'
# data augmentation
config.DATASET.FLIP = True
config.DATASET.ROT_FACTOR = 40
config.DATASET.SCALE_FACTOR = 0.3

# for train
config.TRAIN = edict()
config.TRAIN.BATCH_SIZE = 64
config.TRAIN.BEGIN_EPOCH = 0
config.TRAIN.END_EPOCH = 140
config.TRAIN.LR = 0.001
config.TRAIN.LR_FACTOR = 0.1
config.TRAIN.LR_STEP = [90, 120]

# test
config.TEST = edict()
config.TEST.BATCH_SIZE = 32
config.TEST.FLIP_TEST = True
config.TEST.POST_PROCESS = True
config.TEST.SHIFT_HEATMAP = True
config.TEST.USE_GT_BBOX = False
config.TEST.MODEL_FILE = ''
config.TEST.COCO_BBOX_FILE = 'experiments/COCO_val2017_detections_AP_H_56_person.json'
# nms
config.TEST.OKS_THRE = 0.9
config.TEST.IN_VIS_THRE = 0.2
config.TEST.BBOX_THRE = 1.0
config.TEST.IMAGE_THRE = 0.0
config.TEST.NMS_THRE = 1.0
@@ -38,9 +38,9 @@ class KeypointDatasetGenerator:
        self.image_width = cfg.MODEL.IMAGE_SIZE[0]
        self.image_height = cfg.MODEL.IMAGE_SIZE[1]
        self.aspect_ratio = self.image_width * 1.0 / self.image_height
        self.heatmap_size = np.array(cfg.MODEL.EXTRA.HEATMAP_SIZE, dtype=np.int32)
        self.sigma = cfg.MODEL.EXTRA.SIGMA
        self.target_type = cfg.MODEL.EXTRA.TARGET_TYPE
        self.heatmap_size = np.array(cfg.POSE_RESNET.HEATMAP_SIZE, dtype=np.int32)
        self.sigma = cfg.POSE_RESNET.SIGMA
        self.target_type = cfg.POSE_RESNET.TARGET_TYPE

        # data augmentation
        self.scale_factor = cfg.DATASET.SCALE_FACTOR
@@ -81,7 +81,7 @@ class PoseResNet(nn.Cell):

    def __init__(self, block, layers, cfg, pytorch_mode=True):
        self.inplanes = 64
        extra = cfg.MODEL.EXTRA
        extra = cfg.POSE_RESNET
        self.deconv_with_bias = extra.DECONV_WITH_BIAS

        super(PoseResNet, self).__init__()

@@ -214,7 +214,7 @@ resnet_spec = {50: (Bottleneck, [3, 4, 6, 3]),


def get_pose_net(cfg, is_train, ckpt_path=None, pytorch_mode=False):
    num_layers = cfg.MODEL.EXTRA.NUM_LAYERS
    num_layers = cfg.POSE_RESNET.NUM_LAYERS

    block_class, layers = resnet_spec[num_layers]
    model = PoseResNet(block_class, layers, cfg, pytorch_mode=pytorch_mode)
@@ -0,0 +1,149 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            # print(cfg_helper)
        except:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def extra_operations(cfg):
    """
    Do extra work on Config object.

    Args:
        cfg: Object after instantiation of class 'Config'.
    """
    if cfg.eval_model_file:
        cfg.TEST.MODEL_FILE = cfg.eval_model_file
    if cfg.coco_bbox_file:
        cfg.TEST.COCO_BBOX_FILE = cfg.coco_bbox_file
    if cfg.batch_size:
        cfg.TRAIN.BATCH_SIZE = cfg.batch_size


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    config_obj = Config(final_config)
    extra_operations(config_obj)
    return config_obj


config = get_config()


if __name__ == '__main__':
    print(config)
@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Device adapter for ModelArts"""

from src.model_utils.config import config

if config.enable_modelarts:
    from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
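A small usage sketch (hypothetical snippet, assuming the package layout above): callers import one module and automatically get the implementation matching the current run mode.

```python
# Sketch: the same imports work locally and on ModelArts, because device_adapter
# re-exports either local_adapter or moxing_adapter based on config.enable_modelarts.
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id

print(get_device_id())   # DEVICE_ID environment variable locally (0 if unset)
print(get_device_num())  # RANK_SIZE environment variable locally (1 if unset)
print(get_rank_id())     # RANK_ID environment variable locally (0 if unset)
```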
@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Local adapter"""

import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
@@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from a remote OBS url to a local directory if the first argument is a remote url and the second one is a local path.
    Upload data from a local directory to remote OBS in the opposite case.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
            # print("os.mknod({}) success".format(sync_lock))
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
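A minimal usage sketch of the wrapper (hypothetical entry point; `train.py` and `eval.py` in this commit apply the same pattern): on ModelArts the decorated function runs between the data download and the output upload, while locally those sync steps are skipped.

```python
# Hypothetical minimal entry point showing how moxing_wrapper is meant to be applied.
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper


def prepare():
    # e.g. point the dataset root at the directory the wrapper just downloaded
    config.DATASET.ROOT = config.data_path


@moxing_wrapper(pre_process=prepare)
def run_train():
    print("training with batch size", config.TRAIN.BATCH_SIZE)


if __name__ == '__main__':
    run_train()
```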
@@ -13,24 +13,25 @@
# limitations under the License.
# ============================================================================
import os
import argparse
import time
import numpy as np

from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.communication.management import init
from mindspore.train import Model
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.nn.optim import Adam
from mindspore.common import set_seed

from src.config import config
from src.model import get_pose_net
from src.network_define import JointsMSELoss, WithLossCell
from src.dataset import keypoint_dataset
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id

set_seed(1)
device_id = int(os.getenv('DEVICE_ID'))


def get_lr(begin_epoch,

@@ -65,39 +66,76 @@ def get_lr(begin_epoch,
    learning_rate = lr_each_step[current_step:]
    return learning_rate

def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not zip.")
        else:
            print("Zip has been extracted.")

def parse_args():
    parser = argparse.ArgumentParser(description="Simpleposenet training")
    parser.add_argument("--run-distribute",
                        help="Run distribute, default is false.",
                        action='store_true')
    parser.add_argument('--ckpt-path', type=str, help='ckpt path to save')
    parser.add_argument('--batch-size', type=int, help='training batch size')
    args = parser.parse_args()
    return args
    if config.modelarts_dataset_unzip_name:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

def main():
    # load parse and config
    print("loading parse...")
    args = parse_args()
    if args.batch_size:
        config.TRAIN.BATCH_SIZE = args.batch_size
        # Each server contains 8 devices at most.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
    config.ckpt_save_dir = os.path.join(config.output_path, config.ckpt_save_dir)
    config.DATASET.ROOT = config.data_path
    config.MODEL.PRETRAINED = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.MODEL.PRETRAINED)

@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
    print('batch size :{}'.format(config.TRAIN.BATCH_SIZE))

    # distribution and context
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False,
                        device_id=device_id)
                        device_id=get_device_id())

    if args.run_distribute:
        init()
        rank = get_rank()
        device_num = get_group_size()
    if config.run_distribute:
        rank = get_rank_id()
        device_num = get_device_num()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        init()
    else:
        rank = 0
        device_num = 1

@@ -133,9 +171,11 @@ def main():
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if args.ckpt_path and rank_save_flag:
        config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=20)
        ckpoint_cb = ModelCheckpoint(prefix="simplepose", directory=args.ckpt_path, config=config_ck)
    if config.run_distribute:
        config.ckpt_save_dir = os.path.join(config.ckpt_save_dir, str(get_rank_id()))
    if config.ckpt_save_dir and rank_save_flag:
        config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=5)
        ckpoint_cb = ModelCheckpoint(prefix="simplepose", directory=config.ckpt_save_dir, config=config_ck)
    cb.append(ckpoint_cb)
    # train model
    model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")

@@ -145,4 +185,4 @@ def main():


if __name__ == '__main__':
    main()
    run_train()