forked from mindspore-Ecosystem/mindspore
!16852 ssd_mobilenetV2_master_PR
Merge pull request !16852 from 陈宇凡/ssd_mobilenetV2_master
This commit is contained in commit 3434a6f2c6
@ -0,0 +1,324 @@

# Contents

- [Contents](#contents)
- [SSD Description](#ssd-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
    - [Prepare the model](#prepare-the-model)
    - [Run the scripts](#run-the-scripts)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training on Ascend](#training-on-ascend)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation on Ascend](#evaluation-on-ascend)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

## [SSD Description](#contents)

SSD discretizes the output space of bounding boxes into a set of default boxes over different aspect ratios and scales per feature map location. At prediction time, the network generates scores for the presence of each object category in each default box and produces adjustments to the box to better match the object shape. Additionally, the network combines predictions from multiple feature maps with different resolutions to naturally handle objects of various sizes.

[Paper](https://arxiv.org/abs/1512.02325): Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, Alexander C. Berg. "SSD: Single Shot MultiBox Detector." European Conference on Computer Vision (ECCV), 2016.

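Concretely, in this implementation the network returns two tensors per image at inference time, which `eval.py` reads as `output[0]` (decoded boxes) and `output[1]` (per-class scores). Below is a minimal sketch of their shapes and of score filtering, assuming the `src/config.py` defaults (2034 default boxes, 81 classes, `min_score` 0.1); it is an illustration, not the network code:

```python
import numpy as np

# Stand-in tensors with the per-image output shapes; a sketch only.
boxes = np.random.rand(2034, 4).astype(np.float32)        # decoded [y1, x1, y2, x2] per default box
box_scores = np.random.rand(2034, 81).astype(np.float32)  # per-class confidence per default box

# Detections of class c survive if their score clears config.min_score (0.1);
# per-class NMS (config.nms_threshold = 0.6) then prunes overlapping boxes.
c = 1  # 'person' in config.classes
keep = box_scores[:, c] > 0.1
print(boxes[keep].shape)
```
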
## [Model Architecture](#contents)

The SSD approach is based on a feed-forward convolutional network that produces a fixed-size collection of bounding boxes and scores for the presence of object class instances in those boxes, followed by a non-maximum suppression step that produces the final detections. The early network layers are based on a standard architecture used for high-quality image classification, called the base network. Auxiliary structure is then added to the network to produce detections.

- **ssd320**: the reference model from the paper, using MobileNetV2 as the backbone and the same bbox predictor as the paper presents (a quick check of the resulting default-box count follows below).

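As a quick sanity check, the total number of default boxes the 320x320 model predicts over can be derived from `feature_size` and `num_default` in `src/config.py`, and it matches the configured `num_ssd_boxes`:

```python
# Each feature map contributes (H * W) locations, each with num_default anchors.
feature_size = [20, 10, 5, 3, 2, 1]  # from src/config.py
num_default = [3, 6, 6, 6, 6, 6]     # from src/config.py

num_ssd_boxes = sum(f * f * n for f, n in zip(feature_size, num_default))
print(num_ssd_boxes)  # 2034, the value of "num_ssd_boxes" in src/config.py
```
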
## [Dataset](#contents)

Note that you can run the scripts with the dataset mentioned in the original paper or with one widely used in this domain/network architecture. In the following sections, we will introduce how to run the scripts using the dataset below.

Dataset used: [COCO2017](<http://images.cocodataset.org/>)

- Dataset size: 19G
    - Train: 18G, 118000 images
    - Val: 1G, 5000 images
    - Annotations: 241M, instances, captions, person_keypoints, etc.
- Data format: images and json files
- Note: data will be processed in dataset.py

## [Environment Requirements](#contents)

- Install [MindSpore](https://www.mindspore.cn/install/en).

- Download the dataset COCO2017.

- We use COCO2017 as the training dataset in this example by default, and you can also use your own datasets.
  First, install Cython, pycocotools and opencv-python to process the data and to get the evaluation result.

    ```shell
    pip install Cython
    pip install pycocotools
    pip install opencv-python
    ```

1. If the COCO dataset is used. **Select dataset to coco when running the script.**

    Change `coco_root` and other settings you need in `src/config.py`. The directory structure is as follows:

    ```shell
    .
    └─coco_dataset
      ├─annotations
        ├─instances_train2017.json
        └─instances_val2017.json
      ├─val2017
      └─train2017
    ```

2. If the VOC dataset is used. **Select dataset to voc when running the script.**

    Change `classes`, `num_classes`, `voc_json` and `voc_root` in `src/config.py`. `voc_json` is the path of the COCO-format json file used for evaluation, and `voc_root` is the path of the VOC dataset. The directory structure is as follows:

    ```shell
    .
    └─voc_dataset
      └─train
        ├─0001.jpg
        └─0001.xml
        ...
        ├─xxxx.jpg
        └─xxxx.xml
      └─eval
        ├─0001.jpg
        └─0001.xml
        ...
        ├─xxxx.jpg
        └─xxxx.xml
    ```

3. If your own dataset is used. **Select dataset to other when running the script.**

    Organize the dataset information into a TXT file, in which each row looks as follows:

    ```shell
    train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
    ```

    Each row is an image annotation split by spaces: the first column is the relative path of the image, and the rest are boxes and class information in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path joined by `image_dir` (the dataset directory) and the relative path in `anno_path` (the TXT file path); `image_dir` and `anno_path` are set in `src/config.py`. A parsing sketch follows below.

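For reference, a minimal sketch of how one such row is interpreted, mirroring `anno_parser` and `filter_valid_data` in `src/dataset.py`:

```python
line = "train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2"

parts = line.strip().split(' ')
image_path = parts[0]  # joined with image_dir when the image is read
boxes = [list(map(int, p.split(','))) for p in parts[1:]]
print(image_path)  # train2017/0000001.jpg
print(boxes[0])    # [0, 259, 401, 459, 7] -> [xmin, ymin, xmax, ymax, class]
```
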
## [Quick Start](#contents)

### Prepare the model

Change the dataset settings in `src/config.py`.

### Run the scripts

After installing MindSpore via the official website, you can start training and evaluation as follows:

- running on Ascend

```shell
# distributed training on Ascend
sh scripts/run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] [RANK_TABLE_FILE]

# run eval on Ascend
sh scripts/run_eval.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
```

## [Script Description](#contents)

```shell
.
└─ cv
  └─ ssd
    ├─ README.md                    # descriptions about SSD
    ├─ scripts
      ├─ run_distribute_train.sh    # shell script for distributed training on Ascend
      └─ run_eval.sh                # shell script for eval on Ascend
    ├─ src
      ├─ __init__.py                # init file
      ├─ box_utils.py               # bbox utils
      ├─ eval_utils.py              # metrics utils
      ├─ config.py                  # total config
      ├─ dataset.py                 # create and process dataset
      ├─ init_params.py             # parameters utils
      ├─ lr_schedule.py             # learning rate generator
      └─ ssd.py                     # ssd architecture
    ├─ eval.py                      # eval script
    └─ train.py                     # train script
```

### [Script Parameters](#contents)

```shell
Major parameters in train.py and config.py are as follows:

"device_num": 1                            # Number of devices used
"lr": 0.05                                 # Initial learning rate
"dataset": coco                            # Dataset name
"epoch_size": 500                          # Epoch size
"batch_size": 32                           # Batch size of input tensor
"pre_trained": None                        # Pretrained checkpoint file path
"pre_trained_epoch_size": 0                # Pretrained epoch size
"save_checkpoint_epochs": 10               # The epoch interval between two checkpoints. By default, the checkpoint is saved every 10 epochs
"loss_scale": 1024                         # Loss scale
"filter_weight": False                     # Whether to load parameters of the head layer. Set True if the number of classes in the training dataset differs from the number in the pretrained checkpoint
"freeze_layer": "none"                     # Whether to freeze backbone parameters; supports "none" and "backbone"

"class_num": 81                            # Number of dataset classes
"image_shape": [320, 320]                  # Image height and width used as input to the model
"mindrecord_dir": "/data/MindRecord_COCO"  # MindRecord path
"coco_root": "/data/coco2017"              # COCO2017 dataset path
"voc_root": "/data/voc_dataset"            # VOC original dataset path
"voc_json": "annotations/voc_instances_val.json"  # Path of the COCO-format json file used for VOC evaluation
"image_dir": ""                            # Image path of other datasets; useless if coco or voc is used
"anno_path": ""                            # Annotation path of other datasets; useless if coco or voc is used
```

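These settings live in the `EasyDict` defined in `src/config.py`, so scripts read them (and can override them, as `eval.py` does with `coco_root`) as plain attributes; a minimal sketch:

```python
from src.config import config

# config is an EasyDict, so every entry is an attribute:
print(config.img_shape)    # [320, 320]
print(config.num_classes)  # 81

# entries can also be overridden at runtime, e.g. eval.py adjusts coco_root:
config.coco_root = "/data/coco2017"
```
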
### [Training Process](#contents)

To train the model, run `train.py`. If `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/convert_dataset.html) files from `coco_root` (COCO dataset), `voc_root` (VOC dataset), or `image_dir` and `anno_path` (own dataset). **Note that if `mindrecord_dir` isn't empty, it will use `mindrecord_dir` instead of the raw images.**

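The shell scripts pre-generate these files by running `python train.py --only_create_dataset=True --dataset=...`; the helper in `src/dataset.py` can also be called directly. A minimal sketch:

```python
from src.dataset import create_mindrecord

# On the first call this writes ssd.mindrecord0 ... ssd.mindrecord7 under
# config.mindrecord_dir and returns the first shard's path; later calls
# just return that path.
mindrecord_file = create_mindrecord(dataset="coco", prefix="ssd.mindrecord", is_training=True)
print(mindrecord_file)
```
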
#### Training on Ascend

- Distribute mode

```shell
sh scripts/run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] [RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)
```

This script requires five or seven parameters.

- `DEVICE_NUM`: the number of devices for distributed training.
- `EPOCH_SIZE`: the epoch size for distributed training.
- `LR`: the initial learning rate for distributed training.
- `DATASET`: the dataset mode for distributed training.
- `RANK_TABLE_FILE`: the path of [rank_table.json](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools); it is better to use an absolute path.
- `PRE_TRAINED`: the path of the pretrained checkpoint file; it is better to use an absolute path.
- `PRE_TRAINED_EPOCH_SIZE`: the epoch size of the pretrained checkpoint.

The training result will be stored in the current path, in a folder whose name begins with "LOG". There you can find the checkpoint files together with results like the following in the log:

```shell
epoch: 1 step: 458, loss is 2.329789
epoch time: 522433.474 ms, per step time: 1140.684 ms
epoch: 2 step: 458, loss is 2.1185513
epoch time: 32531.105 ms, per step time: 71.029 ms
epoch: 3 step: 458, loss is 1.9073256
epoch time: 32643.957 ms, per step time: 71.275 ms
...

epoch: 498 step: 458, loss is 0.6682728
epoch time: 31163.108 ms, per step time: 68.042 ms
epoch: 499 step: 458, loss is 0.8796004
epoch time: 31107.760 ms, per step time: 67.921 ms
epoch: 500 step: 458, loss is 0.7718496
epoch time: 32848.501 ms, per step time: 71.722 ms
```

- Single mode

```shell
sh scripts/run_1p_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] [DATASET] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)
```

This script requires four or six parameters.

- `DEVICE_ID`: the device ID for training.
- `EPOCH_SIZE`: the epoch size for training.
- `LR`: the initial learning rate for training.
- `DATASET`: the dataset mode for training.
- `PRE_TRAINED`: the path of the pretrained checkpoint file; it is better to use an absolute path.
- `PRE_TRAINED_EPOCH_SIZE`: the epoch size of the pretrained checkpoint.

The training result will be stored in the current path, in a folder whose name begins with "LOG". There you can find the checkpoint files together with results like the following in the log:

```shell
epoch: 1 step: 3664, loss is 2.1746433
epoch time: 383006.976 ms, per step time: 104.532 ms
epoch: 2 step: 3664, loss is 2.1719098
epoch time: 227088.618 ms, per step time: 61.978 ms
```

### [Evaluation Process](#contents)

#### Evaluation on Ascend

```shell
sh scripts/run_eval.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]
```

This script requires three parameters.

- `DATASET`: the dataset mode of the evaluation dataset.
- `CHECKPOINT_PATH`: the absolute path of the checkpoint file.
- `DEVICE_ID`: the device ID for evaluation.

> The checkpoint can be produced in the training process.

The inference result will be stored in the example path, in a folder whose name begins with "eval". There you can find results like the following in the log:

```shell
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.253
Average Precision (AP) @[ IoU=0.50      | area= all | maxDets=100 ] = 0.415
Average Precision (AP) @[ IoU=0.75      | area= all | maxDets=100 ] = 0.257
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.045
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.222
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.453
Average Recall    (AR) @[ IoU=0.50:0.95 | area= all | maxDets=  1 ] = 0.259
Average Recall    (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.405
Average Recall    (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.438
Average Recall    (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.131
Average Recall    (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.457
Average Recall    (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.704

========================================

mAP: 0.2527925497483538
```

## [Model Description](#contents)

### [Performance](#contents)

#### Evaluation Performance

| Parameters          | Ascend                                          |
| ------------------- | ----------------------------------------------- |
| Model Version       | SSD MobileNetV2                                 |
| Resource            | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G |
| Uploaded Date       | 03/12/2021 (month/day/year)                     |
| MindSpore Version   | 1.1.1                                           |
| Dataset             | COCO2017                                        |
| Training Parameters | epoch = 500, batch_size = 32                    |
| Optimizer           | Momentum                                        |
| Loss Function       | Sigmoid Cross Entropy, SmoothL1Loss             |
| Speed               | 8pcs: 80 ms/step                                |
| Total time          | 8pcs: 4.67 hours                                |
| Scripts             | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/SSD_mobielnetV2> |

#### Inference Performance

| Parameters        | Ascend                      |
| ----------------- | --------------------------- |
| Model Version     | SSD MobileNetV2             |
| Resource          | Ascend 910                  |
| Uploaded Date     | 03/12/2021 (month/day/year) |
| MindSpore Version | 1.1.1                       |
| Dataset           | COCO2017                    |
| batch_size        | 1                           |
| outputs           | mAP                         |
| Accuracy          | IoU=0.50: 25.28%            |

## [Description of Random Situation](#contents)

In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.

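As a hedged sketch of what such seeding typically looks like in MindSpore (not the repository's literal code):

```python
import numpy as np
import mindspore.dataset as de

# Assumed illustration: fix the seeds that drive dataset shuffling and the
# NumPy-based augmentations in dataset.py so that runs are repeatable.
de.config.set_seed(1)
np.random.seed(1)
```
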
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,121 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluation for SSD"""

import ast
import os
import argparse
import time
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD320, SsdInferWithDecoder, ssd_mobilenet_v2
from src.dataset import create_ssd_dataset, create_mindrecord
from src.config import config
from src.eval_utils import metrics
from src.box_utils import default_boxes


def ssd_eval(dataset_path, ckpt_path, anno_json):
    """SSD evaluation."""
    batch_size = 1
    ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1,
                            is_training=False, use_multiprocessing=False)
    net = SSD320(ssd_mobilenet_v2(), config, is_training=False)
    net = SsdInferWithDecoder(net, Tensor(default_boxes), config)

    print("Load Checkpoint!")
    param_dict = load_checkpoint(ckpt_path)
    net.init_parameters_data()
    load_param_into_net(net, param_dict)

    net.set_train(False)
    i = batch_size
    total = ds.get_dataset_size() * batch_size
    start = time.time()
    # Accumulate per-image predictions for COCO-style metric computation.
    pred_data = []
    print("\n========================================\n")
    print("total images num: ", total)
    print("Processing, please wait a moment.")
    for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1):
        img_id = data['img_id']
        img_np = data['image']
        image_shape = data['image_shape']

        output = net(Tensor(img_np))
        for batch_idx in range(img_np.shape[0]):
            pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
                              "box_scores": output[1].asnumpy()[batch_idx],
                              "img_id": int(np.squeeze(img_id[batch_idx])),
                              "image_shape": image_shape[batch_idx]})
        percent = round(i / total * 100., 2)

        print(f' {str(percent)} [{i}/{total}]', end='\r')
        i += batch_size
    cost_time = int((time.time() - start) * 1000)
    print(f' 100% [{total}/{total}] cost {cost_time} ms')
    mAP = metrics(pred_data, anno_json)
    print("\n========================================\n")
    print(f"mAP: {mAP}")


def get_eval_args():
    """set arguments"""
    parser = argparse.ArgumentParser(description='SSD evaluation')
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend",),
                        help="run platform, support Ascend.")
    parser.add_argument('--modelarts_mode', type=ast.literal_eval, default=False,
                        help='train on modelarts or not, default is False')
    parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
    parser.add_argument('--train_url', type=str, default=None, help='Train output path')
    parser.add_argument('--mindrecord_mode', type=str, default="mindrecord", choices=("coco", "mindrecord"),
                        help='type of data, default is mindrecord')
    return parser.parse_args()


if __name__ == '__main__':
    args_opt = get_eval_args()
    if args_opt.modelarts_mode:
        import moxing as mox
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=device_id)
        config.coco_root = os.path.join(config.coco_root, str(device_id))
        config.mindrecord_dir = os.path.join(config.mindrecord_dir, str(device_id))
        checkpoint_path = "/cache/ckpt/"
        checkpoint_path = os.path.join(checkpoint_path, str(device_id))
        mox.file.copy_parallel(args_opt.checkpoint_path, checkpoint_path)
        if args_opt.mindrecord_mode == "mindrecord":
            mox.file.copy_parallel(args_opt.data_url, config.mindrecord_dir)
        else:
            mox.file.copy_parallel(args_opt.data_url, config.coco_root)
    else:
        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)

    mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False)

    if args_opt.dataset == "coco":
        json_path = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type))
    elif args_opt.dataset == "voc":
        json_path = os.path.join(config.voc_root, config.voc_json)
    else:
        raise ValueError('SSD eval only supports dataset mode "coco" or "voc"!')
    print("Start Eval!")
    if args_opt.modelarts_mode:
        checkpoint_path = checkpoint_path + '/ssd-500_458.ckpt'
        ssd_eval(mindrecord_file, checkpoint_path, json_path)
        mox.file.copy_parallel(config.mindrecord_dir, args_opt.train_url)
    else:
        ssd_eval(mindrecord_file, args_opt.checkpoint_path, json_path)

@ -0,0 +1,50 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""export"""
import argparse
import numpy as np

from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
import mindspore.common.dtype as mstype
from src.ssd import SSD320, SsdInferWithDecoder, ssd_mobilenet_v2
from src.config import config
from src.box_utils import default_boxes

parser = argparse.ArgumentParser(description='SSD export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="ssd", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend"], default="Ascend",
                    help="device target")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

if __name__ == '__main__':
    net = SSD320(ssd_mobilenet_v2(), config, is_training=False)
    net = SsdInferWithDecoder(net, Tensor(default_boxes), config)

    param_dict = load_checkpoint(args.ckpt_file)
    net.init_parameters_data()
    load_param_into_net(net, param_dict)
    net.set_train(False)

    input_shp = [args.batch_size, 3] + config.img_shape
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp), mstype.float32)
    export(net, input_array, file_name=args.file_name, file_format=args.file_format)

@ -0,0 +1,73 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_1p_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
echo "for example: sh run_1p_train.sh 0 500 0.2 coco /opt/ssd-300.ckpt(optional) 200(optional)"
echo "It is better to use absolute path."
echo "================================================================================================================="

if [ $# != 4 ] && [ $# != 6 ]
then
    echo "Usage: sh run_1p_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] [DATASET] \
[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
    exit 1
fi

# Before starting 1p train, first create mindrecord files.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True --dataset=$4

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

DEVICE_ID=$1
EPOCH_SIZE=$2
LR=$3
DATASET=$4
PRE_TRAINED=$5
PRE_TRAINED_EPOCH_SIZE=$6

rm -rf LOG$1
mkdir ./LOG$1
cp ./*.py ./LOG$1
cp -r ./src ./LOG$1
cp -r ./scripts ./LOG$1
cd ./LOG$1 || exit

echo "start training for device $1"
env > env.log
if [ $# == 4 ]
then
    python train.py \
    --lr=$LR \
    --dataset=$DATASET \
    --device_id=$DEVICE_ID \
    --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi

if [ $# == 6 ]
then
    python train.py \
    --lr=$LR \
    --dataset=$DATASET \
    --device_id=$DEVICE_ID \
    --pre_trained=$PRE_TRAINED \
    --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
    --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi

cd ../

@ -0,0 +1,83 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
echo "for example: sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)"
echo "It is better to use absolute path."
echo "================================================================================================================="

if [ $# != 5 ] && [ $# != 7 ]
then
    echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \
[RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
    exit 1
fi

# Before starting distributed train, first create mindrecord files.
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
python train.py --only_create_dataset=True --dataset=$4

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

export RANK_SIZE=$1
EPOCH_SIZE=$2
LR=$3
DATASET=$4
PRE_TRAINED=$6
PRE_TRAINED_EPOCH_SIZE=$7
export RANK_TABLE_FILE=$5

for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID=$i
    rm -rf LOG$i
    mkdir ./LOG$i
    cp ./*.py ./LOG$i
    cp -r ./src ./LOG$i
    cp -r ./scripts ./LOG$i
    cd ./LOG$i || exit
    export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    if [ $# == 5 ]
    then
        python train.py \
        --distribute=True \
        --lr=$LR \
        --dataset=$DATASET \
        --device_num=$RANK_SIZE \
        --device_id=$DEVICE_ID \
        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
    fi

    if [ $# == 7 ]
    then
        python train.py \
        --distribute=True \
        --lr=$LR \
        --dataset=$DATASET \
        --device_num=$RANK_SIZE \
        --device_id=$DEVICE_ID \
        --pre_trained=$PRE_TRAINED \
        --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
    fi

    cd ../
done

@ -0,0 +1,65 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_eval.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET=$1
CHECKPOINT_PATH=$(get_real_path $2)
echo $DATASET
echo $CHECKPOINT_PATH

if [ ! -f $CHECKPOINT_PATH ]
then
    echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a file"
    exit 1
fi

export DEVICE_NUM=1
export DEVICE_ID=$3
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit

if [ -d "eval$3" ];
then
    rm -rf ./eval$3
fi

mkdir ./eval$3
cp ./*.py ./eval$3
cp -r ./src ./eval$3
cd ./eval$3 || exit
env > env.log
echo "start inferring for device $DEVICE_ID"
python eval.py \
    --dataset=$DATASET \
    --checkpoint_path=$CHECKPOINT_PATH \
    --device_id=$3 > log.txt 2>&1 &
cd ..

@ -0,0 +1,164 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Bbox utils"""

import math
import itertools as it
import numpy as np
from .config import config


class GeneratDefaultBoxes():
    """
    Generate default boxes for SSD, following the order of (W, H, anchor_sizes).
    `self.default_boxes` has a shape of [anchor_sizes, H, W, 4], the last dimension is [y, x, h, w].
    `self.default_boxes_tlbr` has the same shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
    """
    def __init__(self):
        fk = config.img_shape[0] / np.array(config.steps)
        scale_rate = (config.max_scale - config.min_scale) / (len(config.num_default) - 1)
        scales = [config.min_scale + scale_rate * i for i in range(len(config.num_default))] + [1.0]
        self.default_boxes = []
        for idex, feature_size in enumerate(config.feature_size):
            sk1 = scales[idex]
            sk2 = scales[idex + 1]
            sk3 = math.sqrt(sk1 * sk2)
            if idex == 0 and not config.aspect_ratios[idex]:
                w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2)
                all_sizes = [(0.1, 0.1), (w, h), (h, w)]
            else:
                all_sizes = [(sk1, sk1)]
                for aspect_ratio in config.aspect_ratios[idex]:
                    w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
                    all_sizes.append((w, h))
                    all_sizes.append((h, w))
                all_sizes.append((sk3, sk3))

            assert len(all_sizes) == config.num_default[idex]

            for i, j in it.product(range(feature_size), repeat=2):
                for w, h in all_sizes:
                    cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
                    self.default_boxes.append([cy, cx, h, w])

        def to_tlbr(cy, cx, h, w):
            return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2

        # For IoU calculation
        self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype='float32')
        self.default_boxes = np.array(self.default_boxes, dtype='float32')


default_boxes_tlbr = GeneratDefaultBoxes().default_boxes_tlbr
default_boxes = GeneratDefaultBoxes().default_boxes
y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.match_threshold


def ssd_bboxes_encode(boxes):
    """
    Labels anchors with ground truth inputs.

    Args:
        boxes: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls].

    Returns:
        gt_loc: location ground truth with shape [num_anchors, 4].
        gt_label: class ground truth with shape [num_anchors, 1].
        num_matched_boxes: number of positives in an image.
    """

    def jaccard_with_anchors(bbox):
        """Compute jaccard score between a box and the anchors."""
        # Intersection bbox and volume.
        ymin = np.maximum(y1, bbox[0])
        xmin = np.maximum(x1, bbox[1])
        ymax = np.minimum(y2, bbox[2])
        xmax = np.minimum(x2, bbox[3])
        w = np.maximum(xmax - xmin, 0.)
        h = np.maximum(ymax - ymin, 0.)

        # Volumes.
        inter_vol = h * w
        union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
        jaccard = inter_vol / union_vol
        return np.squeeze(jaccard)

    pre_scores = np.zeros((config.num_ssd_boxes), dtype=np.float32)
    t_boxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32)
    t_label = np.zeros((config.num_ssd_boxes), dtype=np.int64)
    for bbox in boxes:
        label = int(bbox[4])
        scores = jaccard_with_anchors(bbox)
        idx = np.argmax(scores)
        scores[idx] = 2.0  # force the best-matching anchor to be kept
        mask = (scores > matching_threshold)
        mask = mask & (scores > pre_scores)
        pre_scores = np.maximum(pre_scores, scores * mask)
        t_label = mask * label + (1 - mask) * t_label
        for i in range(4):
            t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]

    index = np.nonzero(t_label)

    # Transform from tlbr to center form [cy, cx, h, w].
    bboxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32)
    bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
    bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]

    # Encode features: center offsets are normalized by the default box size and
    # prior_scaling[0]; sizes are log-encoded relative to the default box and prior_scaling[1].
    bboxes_t = bboxes[index]
    default_boxes_t = default_boxes[index]
    bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.prior_scaling[0])
    tmp = np.maximum(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4], 0.000001)
    bboxes_t[:, 2:4] = np.log(tmp) / config.prior_scaling[1]
    bboxes[index] = bboxes_t

    num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
    return bboxes, t_label.astype(np.int32), num_match


def ssd_bboxes_decode(boxes):
    """Decode predicted boxes to corner form [y1, x1, y2, x2]."""
    boxes_t = boxes.copy()
    default_boxes_t = default_boxes.copy()
    # Invert the encoding above: recover centers and sizes from the regression outputs.
    boxes_t[:, :2] = boxes_t[:, :2] * config.prior_scaling[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
    boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.prior_scaling[1]) * default_boxes_t[:, 2:4]

    bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)

    bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
    bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2

    return np.clip(bboxes, 0, 1)


def intersect(box_a, box_b):
    """Compute the intersection area of a set of boxes with one box."""
    max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
    min_yx = np.maximum(box_a[:, :2], box_b[:2])
    inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
    return inter[:, 0] * inter[:, 1]


def jaccard_numpy(box_a, box_b):
    """Compute the jaccard overlap of a set of boxes with one box."""
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2] - box_a[:, 0]) *
              (box_a[:, 3] - box_a[:, 1]))
    area_b = ((box_b[2] - box_b[0]) *
              (box_b[3] - box_b[1]))
    union = area_a + area_b - inter
    return inter / union

@ -0,0 +1,82 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Config parameters for SSD models."""

from easydict import EasyDict as ed

config = ed({
    "model": "ssd320",
    "img_shape": [320, 320],
    "num_ssd_boxes": 2034,
    "neg_pre_positive": 3,
    "match_threshold": 0.5,
    "nms_threshold": 0.6,
    "min_score": 0.1,
    "max_boxes": 100,

    # learning rate settings
    "lr_init": 0.001,
    "lr_end_rate": 0.001,
    "warmup_epochs": 2,
    "momentum": 0.9,
    "weight_decay": 1.5e-4,

    # network
    "num_default": [3, 6, 6, 6, 6, 6],
    "extras_in_channels": [256, 576, 1280, 512, 256, 256],
    "extras_out_channels": [576, 1280, 512, 256, 256, 128],
    "extras_strides": [1, 1, 2, 2, 2, 2],
    "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
    "feature_size": [20, 10, 5, 3, 2, 1],
    "min_scale": 0.2,
    "max_scale": 0.95,
    "aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
    "steps": (16, 32, 64, 108, 160, 320),
    "prior_scaling": (0.1, 0.2),
    "gamma": 2.0,
    "alpha": 0.75,

    # `mindrecord_dir` and `coco_root` are better to use absolute path.
    "feature_extractor_base_param": "",
    "mindrecord_dir": "/cache/MindRecord_COCO",
    "coco_root": "/cache/coco2017",
    "train_data_type": "train2017",
    "val_data_type": "val2017",
    "instances_set": "annotations/instances_{}.json",
    "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush'),
    "num_classes": 81,
    # The annotation.json position of voc validation dataset.
    "voc_json": "annotations/voc_instances_val.json",
    # voc original dataset.
    "voc_root": "/data/voc_dataset",
    # if coco or voc used, `image_dir` and `anno_path` are useless.
    "image_dir": "",
    "anno_path": ""
})

@ -0,0 +1,454 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SSD dataset"""
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import os
|
||||
import json
|
||||
import xml.etree.ElementTree as et
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from .config import config
|
||||
from .box_utils import jaccard_numpy, ssd_bboxes_encode
|
||||
|
||||
|
||||
def _rand(a=0., b=1.):
|
||||
"""Generate random."""
|
||||
return np.random.rand() * (b - a) + a
|
||||
|
||||
|
||||
def get_imageId_from_fileName(filename, id_iter):
|
||||
"""Get imageID from fileName if fileName is int, else return id_iter."""
|
||||
filename = os.path.splitext(filename)[0]
|
||||
if filename.isdigit():
|
||||
return int(filename)
|
||||
return id_iter
|
||||
|
||||
|
||||
def random_sample_crop(image, boxes):
|
||||
"""Random Crop the image and boxes"""
|
||||
height, width, _ = image.shape
|
||||
min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9])
|
||||
|
||||
if min_iou is None:
|
||||
return image, boxes
|
||||
|
||||
# max trails (50)
|
||||
for _ in range(50):
|
||||
image_t = image
|
||||
|
||||
w = _rand(0.3, 1.0) * width
|
||||
h = _rand(0.3, 1.0) * height
|
||||
|
||||
# aspect ratio constraint b/t .5 & 2
|
||||
if h / w < 0.5 or h / w > 2:
|
||||
continue
|
||||
|
||||
left = _rand() * (width - w)
|
||||
top = _rand() * (height - h)
|
||||
|
||||
rect = np.array([int(top), int(left), int(top + h), int(left + w)])
|
||||
overlap = jaccard_numpy(boxes, rect)
|
||||
|
||||
# dropout some boxes
|
||||
drop_mask = overlap > 0
|
||||
if not drop_mask.any():
|
||||
continue
|
||||
|
||||
if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2):
|
||||
continue
|
||||
|
||||
image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]
|
||||
|
||||
centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0
|
||||
|
||||
m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
|
||||
m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
|
||||
|
||||
# mask in that both m1 and m2 are true
|
||||
mask = m1 * m2 * drop_mask
|
||||
|
||||
# have any valid boxes? try again if not
|
||||
if not mask.any():
|
||||
continue
|
||||
|
||||
# take only matching gt boxes
|
||||
boxes_t = boxes[mask, :].copy()
|
||||
|
||||
boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2])
|
||||
boxes_t[:, :2] -= rect[:2]
|
||||
boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4])
|
||||
boxes_t[:, 2:4] -= rect[:2]
|
||||
|
||||
return image_t, boxes_t
|
||||
return image, boxes
|
||||
|
||||
|
||||
def preprocess_fn(img_id, image, box, is_training):
|
||||
"""Preprocess function for dataset."""
|
||||
cv2.setNumThreads(2)
|
||||
|
||||
def _infer_data(image, input_shape):
|
||||
img_h, img_w, _ = image.shape
|
||||
input_h, input_w = input_shape
|
||||
|
||||
image = cv2.resize(image, (input_w, input_h))
|
||||
|
||||
# When the channels of image is 1
|
||||
if len(image.shape) == 2:
|
||||
image = np.expand_dims(image, axis=-1)
|
||||
image = np.concatenate([image, image, image], axis=-1)
|
||||
|
||||
return img_id, image, np.array((img_h, img_w), np.float32)
|
||||
|
||||
def _data_aug(image, box, is_training, image_size=(300, 300)):
|
||||
"""Data augmentation function."""
|
||||
ih, iw, _ = image.shape
|
||||
h, w = image_size
|
||||
|
||||
if not is_training:
|
||||
return _infer_data(image, image_size)
|
||||
|
||||
# Random crop
|
||||
box = box.astype(np.float32)
|
||||
image, box = random_sample_crop(image, box)
|
||||
ih, iw, _ = image.shape
|
||||
|
||||
# Resize image
|
||||
image = cv2.resize(image, (w, h))
|
||||
|
||||
# Flip image or not
|
||||
flip = _rand() < .5
|
||||
if flip:
|
||||
image = cv2.flip(image, 1, dst=None)
|
||||
|
||||
# When the channels of image is 1
|
||||
if len(image.shape) == 2:
|
||||
image = np.expand_dims(image, axis=-1)
|
||||
image = np.concatenate([image, image, image], axis=-1)
|
||||
|
||||
box[:, [0, 2]] = box[:, [0, 2]] / ih
|
||||
box[:, [1, 3]] = box[:, [1, 3]] / iw
|
||||
|
||||
if flip:
|
||||
box[:, [1, 3]] = 1 - box[:, [3, 1]]
|
||||
|
||||
box, label, num_match = ssd_bboxes_encode(box)
|
||||
return image, box, label, num_match
|
||||
|
||||
return _data_aug(image, box, is_training, image_size=config.img_shape)
|
||||
|
||||
|
||||
def create_voc_label(is_training):
|
||||
"""Get image path and annotation from VOC."""
|
||||
voc_root = config.voc_root
|
||||
cls_map = {name: i for i, name in enumerate(config.classes)}
|
||||
sub_dir = 'train' if is_training else 'eval'
|
||||
voc_dir = os.path.join(voc_root, sub_dir)
|
||||
if not os.path.isdir(voc_dir):
|
||||
raise ValueError(f'Cannot find {sub_dir} dataset path.')
|
||||
|
||||
image_dir = anno_dir = voc_dir
|
||||
if os.path.isdir(os.path.join(voc_dir, 'Images')):
|
||||
image_dir = os.path.join(voc_dir, 'Images')
|
||||
if os.path.isdir(os.path.join(voc_dir, 'Annotations')):
|
||||
anno_dir = os.path.join(voc_dir, 'Annotations')
|
||||
|
||||
if not is_training:
|
||||
json_file = os.path.join(config.voc_root, config.voc_json)
|
||||
file_dir = os.path.split(json_file)[0]
|
||||
if not os.path.isdir(file_dir):
|
||||
os.makedirs(file_dir)
|
||||
json_dict = {"images": [], "type": "instances", "annotations": [],
|
||||
"categories": []}
|
||||
bnd_id = 1
|
||||
|
||||
image_files_dict = {}
|
||||
image_anno_dict = {}
|
||||
images = []
|
||||
id_iter = 0
|
||||
for anno_file in os.listdir(anno_dir):
|
||||
print(anno_file)
|
||||
if not anno_file.endswith('xml'):
|
||||
continue
|
||||
tree = et.parse(os.path.join(anno_dir, anno_file))
|
||||
root_node = tree.getroot()
|
||||
file_name = root_node.find('filename').text
|
||||
img_id = get_imageId_from_fileName(file_name, id_iter)
|
||||
id_iter += 1
|
||||
image_path = os.path.join(image_dir, file_name)
|
||||
print(image_path)
|
||||
if not os.path.isfile(image_path):
|
||||
print(f'Cannot find image {file_name} according to annotations.')
|
||||
continue
|
||||
|
||||
labels = []
|
||||
for obj in root_node.iter('object'):
|
||||
cls_name = obj.find('name').text
|
||||
if cls_name not in cls_map:
|
||||
print(f'Label "{cls_name}" not in "{config.classes}"')
|
||||
continue
|
||||
bnd_box = obj.find('bndbox')
|
||||
x_min = int(float(bnd_box.find('xmin').text)) - 1
|
||||
y_min = int(float(bnd_box.find('ymin').text)) - 1
|
||||
x_max = int(float(bnd_box.find('xmax').text)) - 1
|
||||
y_max = int(float(bnd_box.find('ymax').text)) - 1
|
||||
labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]])
|
||||
|
||||
if not is_training:
|
||||
o_width = abs(x_max - x_min)
|
||||
o_height = abs(y_max - y_min)
|
||||
ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': \
|
||||
img_id, 'bbox': [x_min, y_min, o_width, o_height], \
|
||||
'category_id': cls_map[cls_name], 'id': bnd_id, \
|
||||
'ignore': 0, \
|
||||
'segmentation': []}
|
||||
json_dict['annotations'].append(ann)
|
||||
bnd_id = bnd_id + 1
|
||||
|
||||
if labels:
|
||||
images.append(img_id)
|
||||
image_files_dict[img_id] = image_path
|
||||
image_anno_dict[img_id] = np.array(labels)
|
||||
|
||||
if not is_training:
|
||||
size = root_node.find("size")
|
||||
width = int(size.find('width').text)
|
||||
height = int(size.find('height').text)
|
||||
image = {'file_name': file_name, 'height': height, 'width': width,
|
||||
'id': img_id}
|
||||
json_dict['images'].append(image)
|
||||
|
||||
if not is_training:
|
||||
for cls_name, cid in cls_map.items():
|
||||
cat = {'supercategory': 'none', 'id': cid, 'name': cls_name}
|
||||
json_dict['categories'].append(cat)
|
||||
json_fp = open(json_file, 'w')
|
||||
json_str = json.dumps(json_dict)
|
||||
json_fp.write(json_str)
|
||||
json_fp.close()
|
||||
|
||||
return images, image_files_dict, image_anno_dict
|
||||
|
||||
|
||||
def create_coco_label(is_training):
|
||||
"""Get image path and annotation from COCO."""
|
||||
from pycocotools.coco import COCO
|
||||
|
||||
coco_root = config.coco_root
|
||||
data_type = config.val_data_type
|
||||
if is_training:
|
||||
data_type = config.train_data_type
|
||||
|
||||
# Classes need to train or test.
|
||||
train_cls = config.classes
|
||||
train_cls_dict = {}
|
||||
for i, cls in enumerate(train_cls):
|
||||
train_cls_dict[cls] = i
|
||||
|
||||
anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
|
||||
|
||||
coco = COCO(anno_json)
|
||||
classs_dict = {}
|
||||
cat_ids = coco.loadCats(coco.getCatIds())
|
||||
for cat in cat_ids:
|
||||
classs_dict[cat["id"]] = cat["name"]
|
||||
|
||||
image_ids = coco.getImgIds()
|
||||
images = []
|
||||
image_path_dict = {}
|
||||
image_anno_dict = {}
|
||||
|
||||
for img_id in image_ids:
|
||||
image_info = coco.loadImgs(img_id)
|
||||
file_name = image_info[0]["file_name"]
|
||||
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
|
||||
anno = coco.loadAnns(anno_ids)
|
||||
image_path = os.path.join(coco_root, data_type, file_name)
|
||||
annos = []
|
||||
iscrowd = False
|
||||
for label in anno:
|
||||
bbox = label["bbox"]
|
||||
class_name = classs_dict[label["category_id"]]
|
||||
iscrowd = iscrowd or label["iscrowd"]
|
||||
if class_name in train_cls:
|
||||
x_min, x_max = bbox[0], bbox[0] + bbox[2]
|
||||
y_min, y_max = bbox[1], bbox[1] + bbox[3]
|
||||
annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]])
|
||||
|
||||
if not is_training and iscrowd:
|
||||
continue
|
||||
if len(annos) >= 1:
|
||||
images.append(img_id)
|
||||
image_path_dict[img_id] = image_path
|
||||
image_anno_dict[img_id] = np.array(annos)
|
||||
|
||||
return images, image_path_dict, image_anno_dict
|
||||
|
||||
|
||||
def anno_parser(annos_str):
|
||||
"""Parse annotation from string to list."""
|
||||
annos = []
|
||||
for anno_str in annos_str:
|
||||
anno = list(map(int, anno_str.strip().split(',')))
|
||||
annos.append(anno)
|
||||
return annos
|
||||
|
||||
|
||||
def filter_valid_data(image_dir, anno_path):
|
||||
"""Filter valid image file, which both in image_dir and anno_path."""
|
||||
images = []
|
||||
image_path_dict = {}
|
||||
image_anno_dict = {}
|
||||
if not os.path.isdir(image_dir):
|
||||
raise RuntimeError("Path given is not valid.")
|
||||
if not os.path.isfile(anno_path):
|
||||
raise RuntimeError("Annotation file is not valid.")
|
||||
|
||||
with open(anno_path, "rb") as f:
|
||||
lines = f.readlines()
|
||||
for img_id, line in enumerate(lines):
|
||||
line_str = line.decode("utf-8").strip()
|
||||
line_split = str(line_str).split(' ')
|
||||
file_name = line_split[0]
|
||||
image_path = os.path.join(image_dir, file_name)
|
||||
if os.path.isfile(image_path):
|
||||
images.append(img_id)
|
||||
image_path_dict[img_id] = image_path
|
||||
image_anno_dict[img_id] = anno_parser(line_split[1:])
|
||||
|
||||
return images, image_path_dict, image_anno_dict
|
||||
|
||||
|
||||
def voc_data_to_mindrecord(mindrecord_dir, is_training, prefix="ssd.mindrecord", file_num=8):
|
||||
"""Create MindRecord file by image_dir and anno_path."""
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
images, image_path_dict, image_anno_dict = create_voc_label(is_training)
|
||||
|
||||
ssd_json = {
|
||||
"img_id": {"type": "int32", "shape": [1]},
|
||||
"image": {"type": "bytes"},
|
||||
"annotation": {"type": "int32", "shape": [-1, 5]},
|
||||
}
|
||||
writer.add_schema(ssd_json, "ssd_json")
|
||||
|
||||
for img_id in images:
|
||||
image_path = image_path_dict[img_id]
|
||||
with open(image_path, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[img_id], dtype=np.int32)
|
||||
img_id = np.array([img_id], dtype=np.int32)
|
||||
row = {"img_id": img_id, "image": img, "annotation": annos}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
|
||||
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8):
|
||||
"""Create MindRecord file."""
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
if dataset == "coco":
|
||||
images, image_path_dict, image_anno_dict = create_coco_label(is_training)
|
||||
else:
|
||||
images, image_path_dict, image_anno_dict = filter_valid_data(config.image_dir, config.anno_path)
|
||||
|
||||
ssd_json = {
|
||||
"img_id": {"type": "int32", "shape": [1]},
|
||||
"image": {"type": "bytes"},
|
||||
"annotation": {"type": "int32", "shape": [-1, 5]},
|
||||
}
|
||||
writer.add_schema(ssd_json, "ssd_json")
|
||||
|
||||
for img_id in images:
|
||||
image_path = image_path_dict[img_id]
|
||||
with open(image_path, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[img_id], dtype=np.int32)
|
||||
img_id = np.array([img_id], dtype=np.int32)
|
||||
row = {"img_id": img_id, "image": img, "annotation": annos}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
|
||||
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
                       is_training=True, num_parallel_workers=64, use_multiprocessing=True):
    """Create SSD dataset with MindDataset."""
    ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
                        shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
    decode = C.Decode()
    ds = ds.map(operations=decode, input_columns=["image"])
    change_swap_op = C.HWC2CHW()
    normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                               std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
    compose_map_func = (lambda img_id, image, annotation: preprocess_fn(img_id, image, annotation, is_training))
    if is_training:
        output_columns = ["image", "box", "label", "num_match"]
        trans = [color_adjust_op, normalize_op, change_swap_op]
    else:
        output_columns = ["img_id", "image", "image_shape"]
        trans = [normalize_op, change_swap_op]
    ds = ds.map(operations=compose_map_func, input_columns=["img_id", "image", "annotation"],
                output_columns=output_columns, column_order=output_columns,
                python_multiprocessing=use_multiprocessing,
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=trans, input_columns=["image"], python_multiprocessing=use_multiprocessing,
                num_parallel_workers=num_parallel_workers)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)
    return ds

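# Usage sketch (illustrative, not part of the original file): with a MindRecord
# shard on disk, a training pipeline could be assembled like this; the path is
# a placeholder.
#
#   ds = create_ssd_dataset("/path/to/ssd.mindrecord0", batch_size=32,
#                           repeat_num=1, device_num=1, rank=0, is_training=True)
#   for batch in ds.create_dict_iterator(output_numpy=True):
#       images, boxes, labels, num_match = (batch["image"], batch["box"],
#                                           batch["label"], batch["num_match"])
#       break  # images are NCHW float32; boxes/labels are per default box
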
def create_mindrecord(dataset="coco", prefix="ssd.mindrecord", is_training=True):
    """create mindrecord file"""
    print("Start create dataset!")
    # It will generate mindrecord files in config.mindrecord_dir,
    # and the file names are ssd.mindrecord0, 1, ... file_num.

    mindrecord_dir = config.mindrecord_dir
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if dataset == "coco":
            if os.path.isdir(config.coco_root):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("coco", is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("coco_root does not exist.")
        elif dataset == "voc":
            if os.path.isdir(config.voc_root):
                print("Create Mindrecord.")
                voc_data_to_mindrecord(mindrecord_dir, is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("voc_root does not exist.")
        else:
            if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("other", is_training, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("image_dir or anno_path does not exist.")
    return mindrecord_file
@@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Coco metrics utils"""

import json
import numpy as np
from .config import config

def apply_nms(all_boxes, all_scores, thres, max_boxes):
    """Apply NMS to bboxes."""
    y1 = all_boxes[:, 0]
    x1 = all_boxes[:, 1]
    y2 = all_boxes[:, 2]
    x2 = all_boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    order = all_scores.argsort()[::-1]
    keep = []

    while order.size > 0:
        i = order[0]
        keep.append(i)

        if len(keep) >= max_boxes:
            break

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thres)[0]

        order = order[inds + 1]
    return keep

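# Usage sketch (illustrative, not part of the original file): boxes are
# (y1, x1, y2, x2) rows, scores a matching 1-D array; indices of the kept
# boxes come back in descending score order.
#
#   boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]],
#                    dtype=np.float32)
#   scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
#   keep = apply_nms(boxes, scores, thres=0.5, max_boxes=100)  # -> [0, 2]
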
def metrics(pred_data, anno_json):
    """Calculate mAP of predicted bboxes."""
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    num_classes = config.num_classes

    # Classes to train or test.
    val_cls = config.classes
    val_cls_dict = {}
    for i, cls in enumerate(val_cls):
        val_cls_dict[i] = cls
    coco_gt = COCO(anno_json)
    class_dict = {}
    cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
    for cat in cat_ids:
        class_dict[cat["name"]] = cat["id"]

    predictions = []
    img_ids = []

    for sample in pred_data:
        pred_boxes = sample['boxes']
        box_scores = sample['box_scores']
        img_id = sample['img_id']
        h, w = sample['image_shape']

        final_boxes = []
        final_label = []
        final_score = []
        img_ids.append(img_id)

        for c in range(1, num_classes):
            class_box_scores = box_scores[:, c]
            score_mask = class_box_scores > config.min_score
            class_box_scores = class_box_scores[score_mask]
            class_boxes = pred_boxes[score_mask] * [h, w, h, w]

            if score_mask.any():
                nms_index = apply_nms(class_boxes, class_box_scores, config.nms_threshold, config.max_boxes)
                class_boxes = class_boxes[nms_index]
                class_box_scores = class_box_scores[nms_index]

                final_boxes += class_boxes.tolist()
                final_score += class_box_scores.tolist()
                final_label += [class_dict[val_cls_dict[c]]] * len(class_box_scores)

        for loc, label, score in zip(final_boxes, final_label, final_score):
            res = {}
            res['image_id'] = img_id
            res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
            res['score'] = score
            res['category_id'] = label
            predictions.append(res)
    with open('predictions.json', 'w') as f:
        json.dump(predictions, f)

    coco_dt = coco_gt.loadRes('predictions.json')
    E = COCOeval(coco_gt, coco_dt, iouType='bbox')
    E.params.imgIds = img_ids
    E.evaluate()
    E.accumulate()
    E.summarize()
    return E.stats[0]
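# Usage sketch (illustrative, not part of the original file): each entry of
# pred_data pairs one image's raw network output with its id and shape; the
# COCO annotation path below is a placeholder.
#
#   pred_data = [{"boxes": boxes, "box_scores": scores,
#                 "img_id": 42, "image_shape": (480, 640)}]
#   mAP = metrics(pred_data, "/path/to/instances_val2017.json")
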
@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameters utils"""

from mindspore.common.initializer import initializer, TruncatedNormal

def init_net_param(network, initialize_mode='TruncatedNormal'):
    """Init the parameters in net."""
    params = network.trainable_params()
    for p in params:
        if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
            if initialize_mode == 'TruncatedNormal':
                p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype))
            else:
                p.set_data(initializer(initialize_mode, p.data.shape, p.data.dtype))

def load_backbone_params(network, param_dict):
    """Init the parameters from pre-train model, default is mobilenetv2."""
    for _, param in network.parameters_and_names():
        param_name = param.name.replace('network.backbone.', '')
        name_split = param_name.split('.')
        if 'features_1' in param_name:
            param_name = param_name.replace('features_1', 'features')
        if 'features_2' in param_name:
            param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:])
        if param_name in param_dict:
            param.set_data(param_dict[param_name].data)

def filter_checkpoint_parameter(param_dict):
    """remove useless parameters"""
    for key in list(param_dict.keys()):
        if 'multi_loc_layers' in key or 'multi_cls_layers' in key:
            del param_dict[key]
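# Usage sketch (illustrative, not part of the original file): drop the SSD
# head weights from a loaded checkpoint before fine-tuning on a dataset with
# a different number of classes; the checkpoint path is a placeholder.
#
#   from mindspore.train.serialization import load_checkpoint, load_param_into_net
#   param_dict = load_checkpoint("/path/to/ssd.ckpt")
#   filter_checkpoint_parameter(param_dict)   # removes multi_loc/multi_cls layers
#   load_param_into_net(network, param_dict)
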
@@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Learning rate schedule"""

import math
import numpy as np

def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """
    Generate the learning rate array.

    Args:
        global_step(int): step at which training resumes
        lr_init(float): initial learning rate
        lr_end(float): final learning rate
        lr_max(float): maximum learning rate
        warmup_epochs(float): number of warmup epochs
        total_epochs(int): total epochs of training
        steps_per_epoch(int): steps of one epoch

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_end + \
                 (lr_max - lr_end) * \
                 (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
        if lr < 0.0:
            lr = 0.0
        lr_each_step.append(lr)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate
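# Usage sketch (illustrative, not part of the original file): linear warmup
# for 2 epochs, then cosine decay from lr_max toward lr_end over the rest.
#
#   lr = get_lr(global_step=0, lr_init=0.001, lr_end=0.0001, lr_max=0.05,
#               warmup_epochs=2, total_epochs=500, steps_per_epoch=458)
#   lr.shape                 # (229000,) == 500 * 458
#   lr[0], lr[915], lr[-1]   # ~0.001 at start, ~0.05 after warmup, ~0.0001 at end
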
@@ -0,0 +1,499 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SSD net based on MobileNetV2."""

import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C

def _make_divisible(v, divisor, min_value=None):
    """Ensures that all layers have a channel number that is divisible by 8."""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

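# Quick check (illustrative, not part of the original file): rounding keeps
# channel counts hardware-friendly without shrinking them by more than 10%.
#
#   _make_divisible(32 * 0.75, 8)   # -> 24 (already a multiple of 8)
#   _make_divisible(33, 8)          # -> 32 (32 >= 0.9 * 33, so no bump)
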
def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
    return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                     padding=0, pad_mode=pad_mod, has_bias=True)


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.97,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):
    in_channels = in_channel
    out_channels = in_channel
    conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',
                      padding=pad)
    conv2 = _conv2d(in_channel, out_channel, kernel_size=1)
    return nn.SequentialCell([conv1, _bn(in_channel), nn.ReLU6(), conv2])

class ConvBNReLU(nn.Cell):
    """
    Convolution/Depthwise fused with Batchnorm and ReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): Channel group. 1 for a plain convolution, the input channel count for depthwise. Default: 1.
        shared_conv(Cell): Use the weight-shared conv. Default: None.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, shared_conv=None):
        super(ConvBNReLU, self).__init__()
        padding = 0
        in_channels = in_planes
        out_channels = out_planes
        if shared_conv is None:
            if groups == 1:
                conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding)
            else:
                out_channels = in_planes
                conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',
                                 padding=padding, group=in_channels)
            layers = [conv, _bn(out_planes), nn.ReLU6()]
        else:
            layers = [shared_conv, _bn(out_planes), nn.ReLU6()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class InvertedResidual(nn.Cell):
    """
    Residual block definition.

    Args:
        inp (int): Input channel.
        oup (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        expand_ratio (int): Expand ratio of the input channel.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> InvertedResidual(3, 256, 1, 1)
    """
    def __init__(self, inp, oup, stride, expand_ratio, last_relu=False):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
            _bn(oup),
        ])
        self.conv = nn.SequentialCell(layers)
        self.cast = P.Cast()
        self.last_relu = last_relu
        self.relu = nn.ReLU6()

    def construct(self, x):
        identity = x
        x = self.conv(x)
        if self.use_res_connect:
            x = identity + x
        if self.last_relu:
            x = self.relu(x)
        return x

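# Shape note (illustrative, not part of the original file): the skip
# connection is only used when stride == 1 and inp == oup, e.g.
#
#   block = InvertedResidual(32, 32, stride=1, expand_ratio=6)   # residual added
#   block = InvertedResidual(32, 64, stride=2, expand_ratio=6)   # plain, downsamples
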
class FlattenConcat(nn.Cell):
    """
    Concatenate predictions into a single tensor.

    Args:
        config (dict): The default config of SSD.

    Returns:
        Tensor, flatten predictions.
    """
    def __init__(self, config):
        super(FlattenConcat, self).__init__()
        self.num_ssd_boxes = config.num_ssd_boxes
        self.concat = P.Concat(axis=1)
        self.transpose = P.Transpose()

    def construct(self, inputs):
        output = ()
        batch_size = F.shape(inputs[0])[0]
        for x in inputs:
            x = self.transpose(x, (0, 2, 3, 1))
            output += (F.reshape(x, (batch_size, -1)),)
        res = self.concat(output)
        return F.reshape(res, (batch_size, self.num_ssd_boxes, -1))

class MultiBox(nn.Cell):
    """
    Multibox conv layers. Each multibox layer contains class conf scores and localization predictions.

    Args:
        config (dict): The default config of SSD.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.
    """
    def __init__(self, config):
        super(MultiBox, self).__init__()
        num_classes = config.num_classes
        out_channels = config.extras_out_channels
        num_default = config.num_default

        loc_layers = []
        cls_layers = []
        for k, out_channel in enumerate(out_channels):
            loc_layers += [_last_conv2d(out_channel, 4 * num_default[k],
                                        kernel_size=3, stride=1, pad_mod='same', pad=0)]
            cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k],
                                        kernel_size=3, stride=1, pad_mod='same', pad=0)]

        self.multi_loc_layers = nn.layer.CellList(loc_layers)
        self.multi_cls_layers = nn.layer.CellList(cls_layers)
        self.flatten_concat = FlattenConcat(config)

    def construct(self, inputs):
        loc_outputs = ()
        cls_outputs = ()
        for i in range(len(self.multi_loc_layers)):
            loc_outputs += (self.multi_loc_layers[i](inputs[i]),)
            cls_outputs += (self.multi_cls_layers[i](inputs[i]),)
        return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)

class SSD320(nn.Cell):
    """
    SSD320 Network. In this model the backbone is MobileNetV2.

    Args:
        backbone (Cell): Backbone Network.
        config (dict): The default config of SSD.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.

    Examples:
        >>> SSD320(backbone=ssd_mobilenet_v2(), config=config)
    """
    def __init__(self, backbone, config, is_training=True):
        super(SSD320, self).__init__()

        self.backbone = backbone
        in_channels = config.extras_in_channels
        out_channels = config.extras_out_channels
        ratios = config.extras_ratio
        strides = config.extras_strides
        residual_list = []
        for i in range(2, len(in_channels)):
            residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i],
                                        expand_ratio=ratios[i], last_relu=True)
            residual_list.append(residual)
        self.multi_residual = nn.layer.CellList(residual_list)
        self.multi_box = MultiBox(config)
        self.is_training = is_training
        if not is_training:
            self.activation = P.Sigmoid()

    def construct(self, x):
        """return pred_loc and pred_label"""
        layer_out_13, output = self.backbone(x)
        multi_feature = (layer_out_13, output)
        feature = output
        for residual in self.multi_residual:
            feature = residual(feature)
            multi_feature += (feature,)
        pred_loc, pred_label = self.multi_box(multi_feature)
        if not self.is_training:
            pred_label = self.activation(pred_label)
        pred_loc = F.cast(pred_loc, mstype.float32)
        pred_label = F.cast(pred_label, mstype.float32)
        return pred_loc, pred_label

class SigmoidFocalClassificationLoss(nn.Cell):
    """
    Sigmoid focal loss for classification.

    Args:
        gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0
        alpha (float): Hyper-parameter to balance the positive and negative examples. Default: 0.25

    Returns:
        Tensor, the focal loss.
    """
    def __init__(self, gamma=2.0, alpha=0.25):
        super(SigmoidFocalClassificationLoss, self).__init__()
        self.sigmoid_cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.sigmoid = P.Sigmoid()
        self.pow = P.Pow()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.gamma = gamma
        self.alpha = alpha

    def construct(self, logits, label):
        label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
        sigmoid_cross_entropy = self.sigmoid_cross_entropy(logits, label)
        sigmoid = self.sigmoid(logits)
        label = F.cast(label, mstype.float32)
        p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
        modulating_factor = self.pow(1 - p_t, self.gamma)
        alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
        focal_loss = modulating_factor * alpha_weight_factor * sigmoid_cross_entropy
        return focal_loss

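# For reference (not part of the original file): this implements the focal
# loss of Lin et al., FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where
# the log(p_t) term comes from the sigmoid cross-entropy above and the
# (1 - p_t)^gamma factor down-weights easy, well-classified examples.
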
class SSDWithLossCell(nn.Cell):
    """
    Provide SSD training loss through the network.

    Args:
        network (Cell): The training network.
        config (dict): SSD config.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, network, config):
        super(SSDWithLossCell, self).__init__()
        self.network = network
        self.less = P.Less()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.expand_dims = P.ExpandDims()
        self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha)
        self.loc_loss = nn.SmoothL1Loss()

    def construct(self, x, gt_loc, gt_label, num_matched_boxes):
        """get loss"""
        pred_loc, pred_label = self.network(x)
        mask = F.cast(self.less(0, gt_label), mstype.float32)
        num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))

        # Localization Loss
        mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4))
        smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc
        loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1)

        # Classification Loss
        loss_cls = self.class_loss(pred_label, gt_label)
        loss_cls = self.reduce_sum(loss_cls, (1, 2))

        return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes)

grad_scale = C.MultitypeFuncGraph("grad_scale")


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * P.Reciprocal()(scale)

class TrainingWrapper(nn.Cell):
    """
    Encapsulation class of SSD network training.

    Appends an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
        use_global_norm(bool): Whether to apply global norm before the optimizer. Default: False
    """
    def __init__(self, network, optimizer, sens=1.0, use_global_norm=False):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.use_global_norm = use_global_norm
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
        self.hyper_map = C.HyperMap()

    def construct(self, *args):
        """opt"""
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        if self.use_global_norm:
            grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_array(self.sens)), grads)
            grads = C.clip_by_global_norm(grads)
        return F.depend(loss, self.optimizer(grads))

class SSDWithMobileNetV2(nn.Cell):
    """
    MobileNetV2 architecture for the SSD backbone.

    Args:
        width_mult (float): Channel multiplier; results are rounded to a multiple of round_nearest. Default is 1.0.
        inverted_residual_setting (list): Inverted residual settings. Default is None.
        round_nearest (int): Round channel counts to a multiple of this value. Default is 8.

    Returns:
        Tensor, the 13th feature after ConvBNReLU in MobileNetV2.
        Tensor, the last feature in MobileNetV2.

    Examples:
        >>> SSDWithMobileNetV2()
    """
    def __init__(self, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        super(SSDWithMobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]
        if len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        layer_index = 0
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                if layer_index == 13:
                    hidden_dim = int(round(input_channel * t))
                    self.expand_layer_conv_13 = ConvBNReLU(input_channel, hidden_dim, kernel_size=1)
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
                layer_index += 1
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))

        self.features_1 = nn.SequentialCell(features[:14])
        self.features_2 = nn.SequentialCell(features[14:])

    def construct(self, x):
        out = self.features_1(x)
        expand_layer_conv_13 = self.expand_layer_conv_13(out)
        out = self.features_2(out)
        return expand_layer_conv_13, out

    def get_out_channels(self):
        return self.last_channel

class SsdInferWithDecoder(nn.Cell):
    """
    SSD Infer wrapper to decode the bbox locations.

    Args:
        network (Cell): the origin ssd infer network without bbox decoder.
        default_boxes (Tensor): the default boxes from the anchor generator.
        config (dict): ssd config.

    Returns:
        Tensor, the locations for bbox after decoding, representing (y0, x0, y1, x1).
        Tensor, the prediction labels.
    """
    def __init__(self, network, default_boxes, config):
        super(SsdInferWithDecoder, self).__init__()
        self.network = network
        self.default_boxes = default_boxes
        self.prior_scaling_xy = config.prior_scaling[0]
        self.prior_scaling_wh = config.prior_scaling[1]

    def construct(self, x):
        """get pred_xy and pred_label"""
        pred_loc, pred_label = self.network(x)

        default_bbox_xy = self.default_boxes[..., :2]
        default_bbox_wh = self.default_boxes[..., 2:]
        pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy
        pred_wh = P.Exp()(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh

        pred_xy_0 = pred_xy - pred_wh / 2.0
        pred_xy_1 = pred_xy + pred_wh / 2.0
        pred_xy = P.Concat(-1)((pred_xy_0, pred_xy_1))
        pred_xy = P.Maximum()(pred_xy, 0)
        pred_xy = P.Minimum()(pred_xy, 1)
        return pred_xy, pred_label


def ssd_mobilenet_v2(**kwargs):
    return SSDWithMobileNetV2(**kwargs)
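# Usage sketch (illustrative, not part of the original file): building the
# inference graph; default_boxes would come from this repo's anchor generator.
#
#   backbone = ssd_mobilenet_v2()
#   net = SSD320(backbone=backbone, config=config, is_training=False)
#   net = SsdInferWithDecoder(net, Tensor(default_boxes), config)
#   pred_boxes, pred_scores = net(images)   # boxes as (y0, x0, y1, x1) in [0, 1]
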
@@ -0,0 +1,169 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Train SSD and get checkpoint files."""

import os
import argparse
import ast
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init, get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.ssd import SSD320, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from src.config import config
from src.dataset import create_ssd_dataset, create_mindrecord
from src.lr_schedule import get_lr
from src.init_params import init_net_param, filter_checkpoint_parameter

set_seed(1)

def get_args():
    """get arguments"""
    parser = argparse.ArgumentParser(description="SSD training")
    parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend",),
                        help="run platform.")
    parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False,
                        help="If set to True, only create Mindrecord, default is False.")
    parser.add_argument("--distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default is False.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--lr", type=float, default=0.05, help="Learning rate, default is 0.05.")
    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
    parser.add_argument('--train_url', type=str, default=None, help='Train output path')
    parser.add_argument('--modelarts_mode', type=ast.literal_eval, default=False,
                        help='train on modelarts or not, default is False')
    parser.add_argument('--mindrecord_mode', type=str, default="mindrecord", choices=("coco", "mindrecord"),
                        help='type of data, default is mindrecord')
    parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
    parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
    parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
    parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 10.")
    parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
    parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                        help="Filter head weight parameters, default is False.")
    parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"],
                        help="freeze the weights of network, support freezing the backbone's weights, "
                             "default is not freezing.")
    args_opt = parser.parse_args()
    return args_opt

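# Example invocations (illustrative, not part of the original file):
#
#   # single-device training on Ascend, COCO already converted to MindRecord
#   python train.py --dataset coco --device_id 0 --lr 0.05 --epoch_size 500
#
#   # fine-tune from a checkpoint, replacing the detection head
#   python train.py --pre_trained /path/to/ssd.ckpt --filter_weight True
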
def main():
    args_opt = get_args()
    if args_opt.modelarts_mode:
        import moxing as mox
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=device_id)
        config.coco_root = os.path.join(config.coco_root, str(device_id))
        config.mindrecord_dir = os.path.join(config.mindrecord_dir, str(device_id))
        if args_opt.distribute:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
            init()
            context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89])
            rank = get_rank()
        else:
            rank = 0
        if args_opt.mindrecord_mode == "mindrecord":
            mox.file.copy_parallel(args_opt.data_url, config.mindrecord_dir)
        else:
            mox.file.copy_parallel(args_opt.data_url, config.coco_root)

    else:
        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform)
        if args_opt.distribute:
            if os.getenv("DEVICE_ID", "not_set").isdigit():
                context.set_context(device_id=int(os.getenv("DEVICE_ID")))
            device_num = args_opt.device_num
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
            init()
            context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 89])
            rank = get_rank()
        else:
            rank = 0
            device_num = 1
            context.set_context(device_id=args_opt.device_id)
    mindrecord_file = create_mindrecord(args_opt.dataset, "ssd.mindrecord", True)
    if args_opt.only_create_dataset:
        if args_opt.modelarts_mode:
            mox.file.copy_parallel(config.mindrecord_dir, args_opt.train_url)
        return

    loss_scale = float(args_opt.loss_scale)

    # When creating MindDataset, use the first mindrecord file, such as ssd.mindrecord0.
    dataset = create_ssd_dataset(mindrecord_file, repeat_num=1, batch_size=args_opt.batch_size,
                                 device_num=device_num, rank=rank)

    dataset_size = dataset.get_dataset_size()
    print("Create dataset done!")

    backbone = ssd_mobilenet_v2()
    ssd = SSD320(backbone=backbone, config=config)
    net = SSDWithLossCell(ssd, config)
    net.to_float(mstype.float16)

    init_net_param(net)

    # checkpoint
    ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
    save_ckpt_path = './ckpt_' + str(rank) + '/'
    ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=save_ckpt_path, config=ckpt_config)

    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        if args_opt.filter_weight:
            filter_checkpoint_parameter(param_dict)
        load_param_into_net(net, param_dict)

    if args_opt.freeze_layer == "backbone":
        for param in backbone.features_1.trainable_params():
            param.requires_grad = False

    lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size,
                       lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=args_opt.epoch_size,
                       steps_per_epoch=dataset_size))

    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
                      config.momentum, config.weight_decay, loss_scale)
    net = TrainingWrapper(net, opt, loss_scale)

    callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
    model = Model(net)
    dataset_sink_mode = False
    if args_opt.mode == "sink":
        print("In sink mode, one epoch returns one loss.")
        dataset_sink_mode = True
    print("Start training SSD. The first epoch will be slower because of graph compilation.")
    model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
    if args_opt.modelarts_mode:
        mox.file.copy_parallel(save_ckpt_path, args_opt.train_url)


if __name__ == '__main__':
    main()