forked from mindspore-Ecosystem/mindspore
Add cspdarknet53 network to modelzoo
This commit is contained in:
parent
438cc78623
commit
55e1538ac6
@ -0,0 +1,305 @@
# Contents

- [CSPDarkNet53 Description](#cspdarknet53-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [CSPDarkNet53 Description](#contents)

CSPDarkNet53 is a simple, highly modularized network architecture for image classification. Its design results in a homogeneous, multi-branch architecture with only a few hyper-parameters to set.

[Paper](https://arxiv.org/pdf/1911.11929.pdf): Chien-Yao Wang, Hong-Yuan Mark Liao, Yueh-Hua Wu, Ping-Yang Chen, Jun-Wei Hsieh, and I-Hau Yeh. CSPNet: A new backbone that can enhance learning capability of CNN. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshop (CVPR Workshop), 2020.

# [Model Architecture](#contents)

The overall network architecture of CSPDarkNet53 is shown below:

[Link](https://arxiv.org/pdf/1911.11929.pdf)

# [Dataset](#contents)

The dataset used is described in the paper.

- Dataset size: 125G, 1250k color images in 1000 classes
    - Train: 120G, 1200k images
    - Test: 5G, 50k images
- Data format: RGB images
- Note: Data will be processed in src/dataset.py

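The dataset loader in `src/dataset.py` uses `ImageFolderDataset`, which expects one sub-directory per class. The layout below is an illustrative assumption for ImageNet, matching the `data_dir` values used in the ModelArts instructions later in this README:

```shell
ImageNet/
├─train/                     # passed as data_dir for training
│  ├─n01440764/
│  │  ├─xxxx.JPEG
│  │  └─...
│  └─.../
└─validation_preprocess/     # passed as data_dir for evaluation
   ├─n01440764/
   └─.../
```
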
# [Features](#contents)

## [Mixed Precision (Ascend)](#contents)

The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both single-precision and half-precision data formats, while maintaining the network accuracy achieved by single-precision training. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log level and then searching for 'reduce precision'.

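As a minimal sketch (not part of the repository scripts), this is the kind of FP16 cast that `eval.py` applies on Ascend through `to_float`; the `nn.Dense` cell is only a stand-in for the real network:

```python
import mindspore.nn as nn
from mindspore.common import dtype as mstype

net = nn.Dense(16, 10)        # placeholder Cell; the repo uses CSPDarknet53 here
net.to_float(mstype.float16)  # run FP16 kernels where supported; FP32 inputs are handled with reduced precision
```
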
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare a hardware environment with an Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

- Running locally with Ascend

```python
# run standalone training example with train.py
python train.py --is_distributed=0 --data_dir=$DATA_DIR > log.txt 2>&1 &

# run standalone training example with the shell script
bash run_standalone_train.sh [DEVICE_ID] [DATA_DIR] (optional)[PATH_CHECKPOINT]

# run distributed training example
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR] (optional)[PATH_CHECKPOINT]

# run evaluation example with eval.py
python eval.py --is_distributed=0 --per_batch_size=1 --pretrained=$PATH_CHECKPOINT --data_dir=$DATA_DIR > log.txt 2>&1 &

# run evaluation example with the shell script
bash run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT]
```

For distributed training, an hccl configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>

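A typical way to generate this file is the `hccl_tools.py` script from the link above; the flags may differ between versions, so treat the following as an illustrative invocation rather than the authoritative one:

```bash
# Generate a rank table for devices 0-7 on the local server (illustrative).
python hccl_tools.py --device_num "[0,8)"
# Pass the generated hccl_*.json file as [RANK_TABLE_FILE] to run_distribute_train.sh.
```
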
```bash
# Train ImageNet 8p on ModelArts
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the default_config.yaml file.
#          Set "is_distributed=1" in the default_config.yaml file.
#          Set "data_dir='/cache/data/ImageNet/train'" in the default_config.yaml file.
#          (optional) Set "checkpoint_url='s3://dir_to_pretrained/'" in the default_config.yaml file.
#          (optional) Set "pretrained='/cache/checkpoint_path/model.ckpt'" in the default_config.yaml file.
#          (optional) Set any other parameters you need in the default_config.yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "is_distributed=1" on the website UI interface.
#          Add "data_dir='/cache/data/ImageNet/train'" on the website UI interface.
#          (optional) Add "checkpoint_url='s3://dir_to_pretrained/'" on the website UI interface.
#          (optional) Add "pretrained='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
#          (optional) Add any other parameters you need on the website UI interface.
# (2) (optional) Upload or copy your pretrained model to the S3 bucket if "pretrained" is set.
# (3) Upload a zip dataset to the S3 bucket. (You could also upload the original dataset, but that can be slow.)
# (4) Set the code directory to "/path/cspdarknet53" on the website UI interface.
# (5) Set the startup file to "train.py" on the website UI interface.
# (6) Set the "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
# (7) Create your job.
#
# Eval ImageNet 1p on ModelArts
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the default_config.yaml file.
#          Set "is_distributed=0" in the default_config.yaml file.
#          Set "per_batch_size=1" in the default_config.yaml file.
#          Set "data_dir='/cache/data/ImageNet/validation_preprocess'" in the default_config.yaml file.
#          Set "checkpoint_url='s3://dir_to_pretrained/'" in the default_config.yaml file.
#          Set "pretrained='/cache/checkpoint_path/model.ckpt'" in the default_config.yaml file.
#          (optional) Set any other parameters you need in the default_config.yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "is_distributed=0" on the website UI interface.
#          Add "per_batch_size=1" on the website UI interface.
#          Add "data_dir='/cache/data/ImageNet/validation_preprocess'" on the website UI interface.
#          Add "checkpoint_url='s3://dir_to_pretrained/'" on the website UI interface.
#          Add "pretrained='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
#          (optional) Add any other parameters you need on the website UI interface.
# (2) Upload or copy your trained model to the S3 bucket.
# (3) Upload a zip dataset to the S3 bucket. (You could also upload the original dataset, but that can be slow.)
# (4) Set the code directory to "/path/cspdarknet53" on the website UI interface.
# (5) Set the startup file to "eval.py" on the website UI interface.
# (6) Set the "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
# (7) Create your job.
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```shell
.
└─cspdarknet53
  ├─README.md
  ├─model_utils
    ├─__init__.py                   # init file
    ├─config.py                     # parse arguments
    ├─device_adapter.py             # device adapter for ModelArts
    ├─local_adapter.py              # local adapter
    ├─moxing_adapter.py             # moxing adapter for ModelArts
  ├─scripts
    ├─run_standalone_train.sh       # launch standalone training on Ascend (1p)
    ├─run_distribute_train.sh       # launch distributed training on Ascend (8p)
    └─run_eval.sh                   # launch evaluation on Ascend
  ├─src
    ├─utils
      ├─__init__.py                 # module init file
      ├─auto_mixed_precision.py     # auto mixed precision
      ├─custom_op.py                # network operations
      ├─logging.py                  # custom logger
      ├─optimizers_init.py          # optimizer parameters
      ├─sampler.py                  # choose samples from the dataset
      ├─var_init.py                 # weight initialization
    ├─__init__.py                   # parameter configuration
    ├─cspdarknet53.py               # network definition
    ├─dataset.py                    # data preprocessing
    ├─head.py                       # common head architecture
    ├─image_classification.py       # image classification
    ├─loss.py                       # customized CrossEntropy loss function
    ├─lr_generator.py               # learning rate generator
    ├─mindspore_hub_conf.py         # mindspore hub config
  ├─default_config.yaml             # configurations
  ├─eval.py                         # evaluate the network
  ├─export.py                       # convert/export checkpoint
  └─train.py                        # train the network
```

## [Script Parameters](#contents)

```python
Major parameters in default_config.yaml are:
'data_dir'          # dataset directory
'pretrained'        # path to a checkpoint used as the pretrained model
'is_distributed'    # whether training is distributed
'per_batch_size'    # batch size on each device
'log_path'          # path where log files are saved
```

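Because `model_utils/config.py` turns every top-level key of `default_config.yaml` into a command-line flag, these parameters can also be overridden at launch time. The paths below are placeholders:

```shell
python train.py --data_dir=/path/to/ImageNet/train --per_batch_size=64 --is_distributed=0 > log.txt 2>&1 &
```
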
## [Training Process](#contents)

### Usage

You can start training using Python or shell scripts. The usage of the shell scripts is as follows:

- Ascend:

```shell
# distributed training (8p)
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR] (optional)[PATH_CHECKPOINT]
# standalone training
bash run_standalone_train.sh [DEVICE_ID] [DATA_DIR] (optional)[PATH_CHECKPOINT]
```

> Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as described at [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like CSPDarkNet53, it is better to export the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compilation time increases with model size.
>
> The script binds processor cores according to `device_num` and the total number of processor cores. If you do not want this, remove the `taskset` operations in `scripts/run_distribute_train.sh`.

### Launch

```python
# training example
python:
    python train.py --is_distributed=0 --pretrained=PATH_CHECKPOINT --data_dir=DATA_DIR > log.txt 2>&1 &

shell:
    # distributed training example (8p)
    bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR] (optional)[PATH_CHECKPOINT]
    # standalone training example
    bash run_standalone_train.sh [DEVICE_ID] [DATA_DIR] (optional)[PATH_CHECKPOINT]
```

### Result

Training results will be stored in the example path. Checkpoints are stored under `./checkpoint` by default, and the training log is redirected to `./log.txt`.

## [Evaluation Process](#contents)

### Usage

You can start evaluation using Python or shell scripts. The usage of the shell script is as follows:

- Ascend:

```shell
bash run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT]
```

### Launch

```python
# eval example
python:
    python eval.py --is_distributed=0 --per_batch_size=1 --pretrained=PATH_CHECKPOINT --data_dir=DATA_DIR > log.txt 2>&1 &

shell:
    bash run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT]
```

> A checkpoint can be produced during the training process.

### Result

Evaluation results will be stored in the example path; you can find the result in `eval.log`.

## Model Export

```shell
python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format [EXPORT_FORMAT]
```

`EXPORT_FORMAT` should be in ["AIR", "MINDIR"].

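For example (the checkpoint path is a placeholder; by default training writes checkpoints under `outputs/`):

```shell
python export.py --ckpt_file /path/to/cspdarknet53.ckpt --device_target Ascend --file_format MINDIR
```
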
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | CSPDarkNet53                                                 |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G; OS Euler2.8 |
| Uploaded Date              | 06/02/2021                                                   |
| MindSpore Version          | 1.2.0                                                        |
| Dataset                    | 1200k images                                                 |
| Batch_size                 | 64                                                           |
| Training Parameters        | default_config.yaml                                          |
| Optimizer                  | Momentum                                                     |
| Loss Function              | CrossEntropy                                                 |
| Outputs                    | probability                                                  |
| Loss                       | 1.78                                                         |
| Total time (8p)            | 14h                                                          |
| Checkpoint for Fine tuning | 217M (.ckpt file)                                            |
| Speed                      | 8p: 3977 imgs/sec                                            |
| Scripts                    | [cspdarknet53 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cspdarknet53) |

### Inference Performance

| Parameters        | Ascend                                                       |
| ----------------- | ------------------------------------------------------------ |
| Model Version     | CSPDarkNet53                                                 |
| Resource          | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G; OS Euler2.8 |
| Uploaded Date     | 06/02/2021                                                   |
| MindSpore Version | 1.2.0                                                        |
| Dataset           | 50k images                                                   |
| Batch_size        | 1                                                            |
| Outputs           | probability                                                  |
| Accuracy          | acc=78.48% (TOP1)                                            |
|                   | acc=94.21% (TOP5)                                            |

# [Description of Random Situation](#contents)

We use random seeds in "train.py", "./src/utils/var_init.py", and "./src/utils/sampler.py".

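If fully reproducible runs are required, the usual approach is to fix the global seeds before building the dataset and network; a minimal sketch (not part of the repository scripts):

```python
import numpy as np
from mindspore import set_seed

set_seed(1)         # fixes MindSpore's global random generators
np.random.seed(1)   # fixes NumPy-based randomness (e.g. sampling utilities)
```
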
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,70 @@
|
|||
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
need_modelarts_dataset_unzip: True
|
||||
modelarts_dataset_unzip_name: "ImageNet"
|
||||
|
||||
# ==============================================================================
|
||||
# default options
|
||||
image_size: "224,224"
|
||||
num_classes: 1000
|
||||
lr: 0.1
|
||||
lr_scheduler: "cosine_annealing"
|
||||
lr_epochs: "30,60,90,120"
|
||||
lr_gamma: 0.1
|
||||
eta_min: 0
|
||||
T_max: 150
|
||||
max_epoch: 150
|
||||
warmup_epochs: 5
|
||||
weight_decay: 0.0001
|
||||
momentum: 0.9
|
||||
is_dynamic_loss_scale: 0
|
||||
loss_scale: 1024
|
||||
label_smooth: 1
|
||||
label_smooth_factor: 0.1
|
||||
ckpt_interval: 1
|
||||
ckpt_save_max: 10
|
||||
ckpt_path: "outputs/"
|
||||
is_save_on_master: 1
|
||||
|
||||
data_dir: ""
|
||||
pretrained: ""
|
||||
is_distributed: 1
|
||||
per_batch_size: 64
|
||||
|
||||
log_path: "outputs/"
|
||||
|
||||
# export options
|
||||
export_batch_size: 1
|
||||
ckpt_file: ""
|
||||
file_name: "cspdarknet53"
|
||||
file_format: "AIR"
|
||||
width: 224
|
||||
height: 224
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: 'Target device type'
|
||||
graph_ckpt: "graph ckpt or feed ckpt"
|
||||
|
||||
# export options
|
||||
export_batch_size: "batch size for export"
|
||||
ckpt_file: "cspdarknet53 ckpt file"
|
||||
file_name: "output air name."
|
||||
file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']"
|
||||
width: "input width"
|
||||
height: "input height"
|
|
@ -0,0 +1,156 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""evaluate imagenet."""
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.utils.logging import get_logger
|
||||
from src.utils.auto_mixed_precision import auto_mixed_precision
|
||||
from src.utils.var_init import load_pretrain_model
|
||||
from src.image_classification import CSPDarknet53
|
||||
from src.dataset import create_dataset
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
|
||||
|
||||
def get_top5_acc(top5_arg, gt_class):
|
||||
sub_count = 0
|
||||
for top5, gt in zip(top5_arg, gt_class):
|
||||
if gt in top5:
|
||||
sub_count += 1
|
||||
return sub_count
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if config.need_modelarts_dataset_unzip:
|
||||
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(config.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains at most 8 devices.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
|
||||
config.log_path = os.path.join(config.output_path, config.log_path)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_eval():
|
||||
'''Eval.'''
|
||||
config.image_size = list(map(int, config.image_size.split(',')))
|
||||
config.rank = get_rank_id()
|
||||
config.group_size = get_device_num()
|
||||
if config.is_distributed or config.group_size > 1:
|
||||
raise ValueError("Not support distribute eval.")
|
||||
config.outputs_dir = os.path.join(config.log_path,
|
||||
datetime.datetime.now().strftime("%Y-%m-%d_time_%H_%M_%S"))
|
||||
config.logger = get_logger(config.outputs_dir, config.rank)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
device_target=config.device_target, save_graphs=False, device_id=get_device_id())
|
||||
config.logger.save_args(config)
|
||||
|
||||
# network
|
||||
config.logger.important_info('start create network')
|
||||
de_dataset = create_dataset(config.data_dir, config.image_size, config.per_batch_size,
|
||||
config.rank, config.group_size, mode="eval")
|
||||
eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True, num_epochs=1)
|
||||
network = CSPDarknet53(num_classes=config.num_classes)
|
||||
load_pretrain_model(config.pretrained, network, config)
|
||||
|
||||
img_tot = 0
|
||||
top1_correct = 0
|
||||
top5_correct = 0
|
||||
if config.device_target == "Ascend":
|
||||
network.to_float(mstype.float16)
|
||||
elif config.device_target == "GPU":
|
||||
auto_mixed_precision(network)
|
||||
else:
|
||||
raise ValueError("Not support device type: {}".format(config.device_target))
|
||||
network.set_train(False)
|
||||
t_start = time.time()
|
||||
for data, gt_classes in eval_dataloader:
|
||||
out = network(Tensor(data, mstype.float32))
|
||||
out = out.asnumpy()
|
||||
|
||||
top1_output = np.argmax(out, (-1))
|
||||
top5_output = np.argsort(out)[:, -5:]
|
||||
|
||||
t1_correct = np.equal(top1_output, gt_classes).sum()
|
||||
top1_correct += t1_correct
|
||||
top5_correct += get_top5_acc(top5_output, gt_classes)
|
||||
img_tot += config.per_batch_size
|
||||
|
||||
t_end = time.time()
|
||||
if config.rank == 0:
|
||||
time_cost = t_end - t_start
|
||||
fps = (img_tot - config.per_batch_size) * config.group_size / time_cost
|
||||
config.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
|
||||
top1_acc = 100.0 * top1_correct / img_tot
|
||||
top5_acc = 100.0 * top5_correct / img_tot
|
||||
config.logger.info("top1_correct={}, tot={}, acc={:.2f}%(TOP1)".format(top1_correct, img_tot, top1_acc))
|
||||
config.logger.info("top5_correct={}, tot={}, acc={:.2f}%(TOP5)".format(top5_correct, img_tot, top5_acc))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_eval()
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, onnx, mindir models"""
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.image_classification import CSPDarknet53
|
||||
|
||||
from model_utils.config import config
|
||||
from model_utils.device_adapter import get_device_id
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
context.set_context(device_id=get_device_id())
|
||||
|
||||
net = CSPDarknet53(num_classes=config.num_classes)
|
||||
param_dict = load_checkpoint(config.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
input_shape = [config.export_batch_size, 3, config.width, config.height]
|
||||
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=input_shape), ms.float32)
|
||||
|
||||
export(net, input_arr, file_name=config.file_name, file_format=config.file_format)
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.cspdarknet53 import CspDarkNet53
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == 'cspdarknet53':
|
||||
return CspDarkNet53(*args, **kwargs)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo, please try 'cspdarknet53'.")
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except yaml.YAMLError as err:
raise ValueError("Failed to parse yaml config file: {}".format(yaml_path)) from err
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from .config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from .config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID', '')
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote OBS to a local directory when the first path is a remote URL and the second is a local path.
Conversely, upload data from a local directory to remote OBS.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains at most 8 devices.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
# Run the main function
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -0,0 +1,66 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR] (option)[PATH_CHECKPOINT]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATA_DIR=$2
|
||||
export RANK_TABLE_FILE=$1
|
||||
export RANK_SIZE=8
|
||||
export HCCL_CONNECT_TIMEOUT=600
|
||||
echo "hccl connect timeout has changed to 600 scecond"
|
||||
|
||||
PATH_CHECKPOINT=""
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PATH_CHECKPOINT=$3
|
||||
fi
|
||||
|
||||
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
|
||||
echo "the number of logical core" $cores
|
||||
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
|
||||
core_gap=`expr $avg_core_per_rank \- 1`
|
||||
echo "avg_core_per_rank" $avg_core_per_rank
|
||||
echo "core_gap" $core_gap
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
start=`expr $i \* $avg_core_per_rank`
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
export DEPLOY_MODE=0
|
||||
export GE_USE_STATIC_MEMORY=1
|
||||
end=`expr $start \+ $core_gap`
|
||||
cmdopt=$start"-"$end
|
||||
|
||||
rm -rf train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp ../*.yaml ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp -r ../model_utils ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
|
||||
env > env.log
|
||||
taskset -c $cmdopt python train.py \
|
||||
--is_distributed=1 \
|
||||
--pretrained=$PATH_CHECKPOINT \
|
||||
--data_dir=$DATA_DIR > log.txt 2>&1 &
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,45 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export RANK_SIZE=1
|
||||
export RANK_ID=0
|
||||
export DEVICE_ID=$1
|
||||
|
||||
DATA_DIR=$2
|
||||
PATH_CHECKPOINT=$3
|
||||
|
||||
rm -rf ./eval$1
|
||||
mkdir ./eval$1
|
||||
cp ../*.py ./eval$1
|
||||
cp ../*.yaml ./eval$1
|
||||
cp -r ../src ./eval$1
|
||||
cp -r ../model_utils ./eval$1
|
||||
cd ./eval$1 || exit
|
||||
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
|
||||
python eval.py \
|
||||
--is_distributed=0 \
|
||||
--per_batch_size=1 \
|
||||
--pretrained=$PATH_CHECKPOINT \
|
||||
--data_dir=$DATA_DIR > log.txt 2>&1 &
|
|
@ -0,0 +1,38 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_standalone_train.sh [DEVICE_ID] [DATA_DIR] (option)[PATH_CHECKPOINT]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export RANK_SIZE=1
|
||||
export RANK_ID=0
|
||||
export DEVICE_ID=$1
|
||||
|
||||
DATA_DIR=$2
|
||||
PATH_CHECKPOINT=""
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PATH_CHECKPOINT=$3
|
||||
fi
|
||||
|
||||
python train.py \
|
||||
--is_distributed=0 \
|
||||
--pretrained=$PATH_CHECKPOINT \
|
||||
--data_dir=$DATA_DIR > log.txt 2>&1 &
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""CSPDarkNet53 model."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Mish(nn.Cell):
|
||||
"""Mish activation method"""
|
||||
def __init__(self):
|
||||
super(Mish, self).__init__()
|
||||
self.mul = P.Mul()
|
||||
self.tanh = P.Tanh()
|
||||
self.softplus = P.Softplus()
|
||||
|
||||
def construct(self, input_x):
|
||||
res1 = self.softplus(input_x)
|
||||
tanh = self.tanh(res1)
|
||||
output = self.mul(input_x, tanh)
|
||||
|
||||
return output
|
||||
|
||||
def conv_block(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
dilation=1):
|
||||
"""Get a conv2d batchnorm and relu layer"""
|
||||
pad_mode = 'same'
|
||||
padding = 0
|
||||
|
||||
return nn.SequentialCell(
|
||||
[nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
pad_mode=pad_mode),
|
||||
nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5),
|
||||
Mish()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
DarkNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channels: Integer. Input channel.
|
||||
out_channels: Integer. Output channel.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
Examples:
|
||||
ResidualBlock(3, 208)
|
||||
"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels):
|
||||
|
||||
super(ResidualBlock, self).__init__()
|
||||
out_chls = out_channels
|
||||
self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
|
||||
self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
|
||||
self.add = P.Add()
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
out = self.add(out, identity)
|
||||
|
||||
return out
|
||||
|
||||
class CspDarkNet53(nn.Cell):
|
||||
"""
|
||||
DarkNet V1 network.
|
||||
|
||||
Args:
|
||||
block: Cell. Block for network.
|
||||
layer_nums: List. Numbers of different layers.
|
||||
in_channels: Integer. Input channel.
|
||||
out_channels: Integer. Output channel.
|
||||
num_classes: Integer. Class number. Default:100.
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output tensor,(f1,f2,f3,f4,f5).
|
||||
|
||||
Examples:
|
||||
DarkNet(ResidualBlock)
|
||||
"""
|
||||
def __init__(self,
|
||||
block,
|
||||
detect=False):
|
||||
super(CspDarkNet53, self).__init__()
|
||||
|
||||
self.outchannel = 1024
|
||||
self.detect = detect
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.add = P.Add()
|
||||
|
||||
self.conv0 = conv_block(3, 32, kernel_size=3, stride=1)
|
||||
self.conv1 = conv_block(32, 64, kernel_size=3, stride=2)
|
||||
self.conv2 = conv_block(64, 64, kernel_size=1, stride=1)
|
||||
self.conv3 = conv_block(64, 32, kernel_size=1, stride=1)
|
||||
self.conv4 = conv_block(32, 64, kernel_size=3, stride=1)
|
||||
self.conv5 = conv_block(64, 64, kernel_size=1, stride=1)
|
||||
self.conv6 = conv_block(64, 64, kernel_size=1, stride=1)
|
||||
self.conv7 = conv_block(128, 64, kernel_size=1, stride=1)
|
||||
self.conv8 = conv_block(64, 128, kernel_size=3, stride=2)
|
||||
self.conv9 = conv_block(128, 64, kernel_size=1, stride=1)
|
||||
self.conv10 = conv_block(64, 64, kernel_size=1, stride=1)
|
||||
self.conv11 = conv_block(128, 64, kernel_size=1, stride=1)
|
||||
self.conv12 = conv_block(128, 128, kernel_size=1, stride=1)
|
||||
self.conv13 = conv_block(128, 256, kernel_size=3, stride=2)
|
||||
self.conv14 = conv_block(256, 128, kernel_size=1, stride=1)
|
||||
self.conv15 = conv_block(128, 128, kernel_size=1, stride=1)
|
||||
self.conv16 = conv_block(256, 128, kernel_size=1, stride=1)
|
||||
self.conv17 = conv_block(256, 256, kernel_size=1, stride=1)
|
||||
self.conv18 = conv_block(256, 512, kernel_size=3, stride=2)
|
||||
self.conv19 = conv_block(512, 256, kernel_size=1, stride=1)
|
||||
self.conv20 = conv_block(256, 256, kernel_size=1, stride=1)
|
||||
self.conv21 = conv_block(512, 256, kernel_size=1, stride=1)
|
||||
self.conv22 = conv_block(512, 512, kernel_size=1, stride=1)
|
||||
self.conv23 = conv_block(512, 1024, kernel_size=3, stride=2)
|
||||
self.conv24 = conv_block(1024, 512, kernel_size=1, stride=1)
|
||||
self.conv25 = conv_block(512, 512, kernel_size=1, stride=1)
|
||||
self.conv26 = conv_block(1024, 512, kernel_size=1, stride=1)
|
||||
self.conv27 = conv_block(1024, 1024, kernel_size=1, stride=1)
|
||||
|
||||
self.layer2 = self._make_layer(block, 2, in_channel=64, out_channel=64)
|
||||
self.layer3 = self._make_layer(block, 8, in_channel=128, out_channel=128)
|
||||
self.layer4 = self._make_layer(block, 8, in_channel=256, out_channel=256)
|
||||
self.layer5 = self._make_layer(block, 4, in_channel=512, out_channel=512)
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, out_channel):
|
||||
"""
|
||||
Make Layer for DarkNet.
|
||||
|
||||
:param block: Cell. DarkNet block.
|
||||
:param layer_num: Integer. Layer number.
|
||||
:param in_channel: Integer. Input channel.
|
||||
:param out_channel: Integer. Output channel.
|
||||
:return: SequentialCell, the output layer.
|
||||
|
||||
Examples:
|
||||
_make_layer(ConvBlock, 1, 128, 256)
|
||||
"""
|
||||
layers = []
|
||||
darkblk = block(in_channel, out_channel)
|
||||
layers.append(darkblk)
|
||||
|
||||
for _ in range(1, layer_num):
|
||||
darkblk = block(out_channel, out_channel)
|
||||
layers.append(darkblk)
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
"""construct method"""
|
||||
c1 = self.conv0(x)
|
||||
c2 = self.conv1(c1) #route
|
||||
c3 = self.conv2(c2)
|
||||
c4 = self.conv3(c3)
|
||||
c5 = self.conv4(c4)
|
||||
c6 = self.add(c3, c5)
|
||||
c7 = self.conv5(c6)
|
||||
c8 = self.conv6(c2)
|
||||
c9 = self.concat((c7, c8))
|
||||
c10 = self.conv7(c9)
|
||||
c11 = self.conv8(c10) #route
|
||||
c12 = self.conv9(c11)
|
||||
c13 = self.layer2(c12)
|
||||
c14 = self.conv10(c13)
|
||||
c15 = self.conv11(c11)
|
||||
c16 = self.concat((c14, c15))
|
||||
c17 = self.conv12(c16)
|
||||
c18 = self.conv13(c17) #route
|
||||
c19 = self.conv14(c18)
|
||||
c20 = self.layer3(c19)
|
||||
c21 = self.conv15(c20)
|
||||
c22 = self.conv16(c18)
|
||||
c23 = self.concat((c21, c22))
|
||||
c24 = self.conv17(c23) #output1
|
||||
c25 = self.conv18(c24) #route
|
||||
c26 = self.conv19(c25)
|
||||
c27 = self.layer4(c26)
|
||||
c28 = self.conv20(c27)
|
||||
c29 = self.conv21(c25)
|
||||
c30 = self.concat((c28, c29))
|
||||
c31 = self.conv22(c30) #output2
|
||||
c32 = self.conv23(c31) #route
|
||||
c33 = self.conv24(c32)
|
||||
c34 = self.layer5(c33)
|
||||
c35 = self.conv25(c34)
|
||||
c36 = self.conv26(c32)
|
||||
c37 = self.concat((c35, c36))
|
||||
c38 = self.conv27(c37) #output3
|
||||
|
||||
if self.detect:
|
||||
return c24, c31, c38
|
||||
|
||||
return c38
|
||||
|
||||
def get_out_channels(self):
|
||||
return self.outchannel
|
||||
|
||||
def cspdarknet53():
|
||||
"""
|
||||
Get CSPDarkNet53 neural network.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of CSPDarkNet53 neural network.
|
||||
|
||||
Examples:
|
||||
cspdarknet53()
|
||||
"""
|
||||
return CspDarkNet53(ResidualBlock)
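
# Illustrative usage (an assumption, not part of the original file): the backbone maps a
# 224x224 RGB batch to a 1024-channel feature map downsampled by 32 (7x7 spatially for 224 input):
#   net = cspdarknet53()
#   features = net(images)       # images: Tensor of shape (N, 3, 224, 224)
#   net.get_out_channels()       # -> 1024, consumed by CommonHead in src/head.py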
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import os
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.vision.c_transforms as V_C
|
||||
from PIL import Image, ImageFile
|
||||
from .utils.sampler import DistributedSampler
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
class TxtDataset():
|
||||
"""
|
||||
create txt dataset.
|
||||
|
||||
Args:
|
||||
Returns:
|
||||
de_dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, root, txt_name):
|
||||
super(TxtDataset, self).__init__()
|
||||
self.imgs = []
|
||||
self.labels = []
|
||||
fin = open(txt_name, 'r')
|
||||
for line in fin:
|
||||
image_name, label = line.strip().split(' ')
|
||||
self.imgs.append(os.path.join(root, image_name))
|
||||
self.labels.append(int(label))
|
||||
fin.close()
|
||||
|
||||
def __getitem__(self, item):
|
||||
img = Image.open(self.imgs[item]).convert('RGB')
|
||||
return img, self.labels[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
|
||||
|
||||
def create_dataset(data_dir, image_size, per_batch_size, rank, group_size,
|
||||
mode="train",
|
||||
input_mode="folder",
|
||||
root='',
|
||||
num_parallel_workers=None,
|
||||
shuffle=None,
|
||||
sampler=None,
|
||||
class_indexing=None,
|
||||
drop_remainder=True,
|
||||
transform=None,
|
||||
target_transform=None):
|
||||
"create ImageNet dataset."
|
||||
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||
|
||||
if transform is None:
|
||||
if mode == "train":
|
||||
transform_img = [
|
||||
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
V_C.RandomHorizontalFlip(prob=0.5),
|
||||
V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
|
||||
V_C.Normalize(mean=mean, std=std),
|
||||
V_C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
transform_img = [
|
||||
V_C.Decode(),
|
||||
V_C.Resize((256, 256)),
|
||||
V_C.CenterCrop(image_size),
|
||||
V_C.Normalize(mean=mean, std=std),
|
||||
V_C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
transform_img = transform
|
||||
|
||||
if target_transform is None:
|
||||
transform_label = [C.TypeCast(mstype.int32)]
|
||||
else:
|
||||
transform_label = target_transform
|
||||
|
||||
|
||||
if input_mode == 'folder':
|
||||
de_dataset = ds.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
|
||||
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
|
||||
num_shards=group_size, shard_id=rank)
|
||||
else:
|
||||
dataset = TxtDataset(root, data_dir)
|
||||
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
|
||||
de_dataset = ds.GeneratorDataset(dataset, ['image', 'label'], sampler=sampler)
|
||||
|
||||
de_dataset = de_dataset.map(operations=transform_img, input_columns="image",
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
de_dataset = de_dataset.map(operations=transform_label, input_columns="label",
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
|
||||
columns_to_project = ['image', 'label']
|
||||
de_dataset = de_dataset.project(columns=columns_to_project)
|
||||
|
||||
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
|
||||
de_dataset = de_dataset.repeat(1)
|
||||
|
||||
return de_dataset
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""common head architecture."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from src.utils.custom_op import GlobalAvgPooling
|
||||
|
||||
__all__ = ["CommonHead"]
|
||||
|
||||
class CommonHead(nn.Cell):
|
||||
"""Common Head."""
|
||||
def __init__(self, num_classes, out_channels):
|
||||
super(CommonHead, self).__init__()
|
||||
self.avgpool = GlobalAvgPooling()
|
||||
self.fc = nn.Dense(out_channels, num_classes, has_bias=True).add_flags_recursive(fp16=True)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.avgpool(x)
|
||||
x = self.fc(x)
|
||||
return x
|
|
@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Image classification."""
import mindspore.nn as nn
from mindspore.ops import operations as P

from src.cspdarknet53 import cspdarknet53
from src.head import CommonHead
from src.utils.var_init import default_recurisive_init

class ImageClassificationNetwork(nn.Cell):
    """Architecture of Image Classification Network."""
    def __init__(self, backbone, head, include_top=True, activation="None"):
        super(ImageClassificationNetwork, self).__init__()
        self.backbone = backbone
        self.include_top = include_top
        self.need_activation = False
        if self.include_top:
            self.head = head
            if activation != "None":
                self.need_activation = True
                if activation == "Sigmoid":
                    self.activation = P.Sigmoid()
                elif activation == "Softmax":
                    self.activation = P.Softmax()
                else:
                    raise NotImplementedError("The activation {} not in ['Sigmoid', 'Softmax'].".format(activation))

    def construct(self, x):
        x = self.backbone(x)
        if self.include_top:
            x = self.head(x)
            if self.need_activation:
                x = self.activation(x)
        return x

class CSPDarknet53(ImageClassificationNetwork):
    """CSPDarknet53 architecture."""
    def __init__(self, num_classes=1000, include_top=True, activation="None"):
        backbone = cspdarknet53()
        out_channels = backbone.get_out_channels()
        head = CommonHead(num_classes=num_classes, out_channels=out_channels)
        super(CSPDarknet53, self).__init__(backbone, head, include_top, activation)

        default_recurisive_init(self)
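A hypothetical smoke test of the assembled classifier; the 224x224 input resolution and device settings are assumptions, and it requires the cspdarknet53 backbone defined elsewhere in this commit:

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from src.image_classification import CSPDarknet53

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = CSPDarknet53(num_classes=1000, activation="Softmax")
net.set_train(False)
images = Tensor(np.random.randn(1, 3, 224, 224), mstype.float32)
probs = net(images)   # softmax probabilities because activation="Softmax"
print(probs.shape)    # expected: (1, 1000)
```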
@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn


class CrossEntropy(Loss):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
    def __init__(self, smooth_factor=0., num_classes=1000):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)
        return loss
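With smooth_factor=0.1 and num_classes=1000, the smoothed one-hot target puts 0.9 on the true class and 0.1/999 on every other class. A hypothetical usage sketch with random logits:

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from src.loss import CrossEntropy

criterion = CrossEntropy(smooth_factor=0.1, num_classes=1000)
logits = Tensor(np.random.randn(4, 1000), mstype.float32)   # raw network outputs
labels = Tensor(np.array([1, 5, 42, 999]), mstype.int32)    # sparse class ids
loss = criterion(logits, labels)                            # scalar, mean over the batch
print(loss)
```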
@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
from collections import Counter
import numpy as np

def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Applies linear warmup to generate the learning rate for the current step."""

    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    lr = float(init_lr) + lr_inc * current_step
    return lr

def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """Applies cosine decay to generate learning rate array."""

    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)


def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """Applies multi-step decay at the given milestone epochs to generate learning rate array."""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    milestones = lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone * steps_per_epoch
        milestones_steps.append(milestones_step)

    lr_each_step = []
    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr * gamma**milestones_steps_counter[i]
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)

def get_lr(args):
    """generate learning rate array"""

    if args.lr_scheduler == "exponential":
        lr = warmup_step_lr(args.lr,
                            args.lr_epochs,
                            args.steps_per_epoch,
                            args.warmup_epochs,
                            args.max_epoch,
                            gamma=args.lr_gamma)

    elif args.lr_scheduler == "cosine_annealing":
        lr = warmup_cosine_annealing_lr(args.lr,
                                        args.steps_per_epoch,
                                        args.warmup_epochs,
                                        args.max_epoch,
                                        args.T_max,
                                        args.eta_min)
    else:
        raise NotImplementedError(args.lr_scheduler)

    return lr
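For intuition, the cosine schedule warms up linearly for the first warmup_epochs and then decays once per epoch toward eta_min. A small sketch with hypothetical settings:

```python
from src.lr_generator import warmup_cosine_annealing_lr

# Hypothetical settings: base lr 0.1, 1250 steps per epoch, 5 warmup epochs,
# 150 training epochs, cosine period T_max = 150 epochs.
lr_array = warmup_cosine_annealing_lr(0.1, steps_per_epoch=1250, warmup_epochs=5,
                                      max_epoch=150, T_max=150, eta_min=0)
print(lr_array.shape)                             # (187500,) == 150 * 1250
print(lr_array[0], lr_array[6249], lr_array[-1])  # first warmup step, last warmup step, final lr
```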
@ -0,0 +1,71 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""auto mixed precision"""
from collections.abc import Iterable
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype


def check_type_name(arg_name, arg_type, valid_types, prim_name):
    """Checks whether a type is in the specified valid types."""
    valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

    def raise_error_msg():
        """func for raising error message when check failed"""
        type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
        num_types = len(valid_types)
        msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
                        f"{type_names if num_types > 1 else type_names[0]}, "
                        f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.")

    if isinstance(arg_type, type(mstype.tensor)):
        arg_type = arg_type.element_type()
    if arg_type not in valid_types:
        raise_error_msg()
    return arg_type

class OutputTo(nn.Cell):
    """Cast cell output back to float16 or float32"""

    def __init__(self, op, to_type=mstype.float16):
        super(OutputTo, self).__init__(auto_prefix=False)
        self._op = op
        check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None)
        self.to_type = to_type

    def construct(self, x):
        return F.cast(self._op(x), self.to_type)


def auto_mixed_precision(network):
    """Cast the network to float16 while keeping BatchNorm (and the 'fc' output) in float32."""
    cells = network.name_cells()
    change = False
    network.to_float(mstype.float16)
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif name == 'fc':
            network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
            change = True
        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
        else:
            auto_mixed_precision(subcell)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())
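train.py below relies on Model(..., amp_level="O2") rather than this helper, but a hypothetical manual use would look like the sketch below; the src.utils.auto_mixed_precision module path is an assumption:

```python
from src.image_classification import CSPDarknet53
from src.utils.auto_mixed_precision import auto_mixed_precision  # assumed module path

net = CSPDarknet53(num_classes=1000)
# Cast cells to float16; BatchNorm stays float32 and the 'fc' output is cast back to float32.
auto_mixed_precision(net)
```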
@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network operations."""
import mindspore.nn as nn
from mindspore.ops import operations as P

class GlobalAvgPooling(nn.Cell):
    '''GlobalAvgPooling'''
    def __init__(self):
        super(GlobalAvgPooling, self).__init__()
        self.mean = P.ReduceMean(False)

    def construct(self, x):
        x = self.mean(x, (2, 3))
        return x
@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom logger."""
import logging
import os
import sys
from datetime import datetime

class LOGGER(logging.Logger):
    '''Logger.'''
    def __init__(self, logger_name, local_rank=0):
        super(LOGGER, self).__init__(logger_name)
        self.local_rank = local_rank
        if local_rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, local_rank=0):
        '''Setup logging file.'''
        self.local_rank = local_rank
        if self.local_rank == 0:
            if not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
            self.log_fn = os.path.join(log_dir, log_name)
            fh = logging.FileHandler(self.log_fn)
            fh.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            fh.setFormatter(formatter)
            self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, rank):
    logger = LOGGER("cspdarknet53", rank)
    logger.setup_logging_file(path, rank)
    return logger
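A minimal sketch of the logger as train.py uses it; only rank 0 writes the timestamped log file under the given directory, and the directory name here is a placeholder:

```python
from src.utils.logging import get_logger

logger = get_logger("./outputs", 0)          # log directory, rank id
logger.info("steps_per_epoch: %d", 1250)
logger.important_info("start create network")
```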
@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer parameters."""

def get_param_groups(net):
    """get param groups"""
    decay_params = []
    no_decay_params = []
    for x in net.trainable_params():
        param_name = x.name
        if param_name.endswith('.bias'):
            no_decay_params.append(x)
        elif param_name.endswith('.gamma'):
            no_decay_params.append(x)
        elif param_name.endswith('.beta'):
            no_decay_params.append(x)
        else:
            decay_params.append(x)

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
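Mirroring the optimizer setup in train.py, bias/gamma/beta parameters are excluded from weight decay while everything else keeps it. A hypothetical sketch; the learning rate, momentum and weight-decay values are placeholders:

```python
from mindspore.nn.optim import Momentum
from src.image_classification import CSPDarknet53
from src.utils.optimizers_init import get_param_groups

net = CSPDarknet53(num_classes=1000)
groups = get_param_groups(net)
print(len(groups[0]['params']), "no-decay params;", len(groups[1]['params']), "decay params")
opt = Momentum(params=groups, learning_rate=0.1, momentum=0.9, weight_decay=1e-4)
```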
@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
choose samples from the dataset
"""
import math
import numpy as np

class DistributedSampler():
    """
    Sample a shard of dataset indices for one device in distributed training.

    Args:
        dataset: dataset to sample from; only its length is used.
        rank (int): rank id of the current device.
        group_size (int): total number of devices.
        shuffle (bool): whether to reshuffle the indices every epoch.
        seed (int): base random seed for shuffling.
    """
    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_len = len(self.dataset)
        self.num_samples = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
        self.total_size = self.num_samples * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_len).tolist()
        else:
            indices = list(range(self.dataset_len))

        indices += indices[:(self.total_size - len(indices))]
        indices = indices[self.rank::self.group_size]
        return iter(indices)

    def __len__(self):
        return self.num_samples
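The shard arithmetic can be checked with a toy list: with 10 samples across 4 devices, each rank sees ceil(10/4) = 3 indices after the index list is padded to 12. The src.utils.sampler module path below is an assumption:

```python
from src.utils.sampler import DistributedSampler  # assumed module path

toy_dataset = list(range(10))                     # anything with __len__
sampler = DistributedSampler(toy_dataset, rank=1, group_size=4, shuffle=False)
print(len(sampler))         # 3
print(list(iter(sampler)))  # [1, 5, 9]: every 4th index starting from rank 1
```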
@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialize."""
import os
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore.common import initializer as init
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
from mindspore import Tensor

def _calculate_gain(nonlinearity, param=None):
    """calculate_gain"""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res


def _assignment(arr, num):
    """Assign the value of `num` to `arr`."""
    if arr.shape == ():
        arr = arr.reshape((1))
        arr[:] = num
        arr.reshape(())
    else:
        if isinstance(num, np.ndarray):
            arr[:] = num[:]
        else:
            arr[:] = num
    return arr


def _calculate_in_and_out(tensor):
    """_calculate_fan_in_and_fan_out"""
    dimensions = len(tensor.shape)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    fan_in = tensor.shape[1]
    fan_out = tensor.shape[0]

    if dimensions > 2:
        counter = reduce(lambda x, y: x * y, tensor.shape[2:])
        fan_in *= counter
        fan_out *= counter

    return fan_in, fan_out


def _select_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_in_and_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


class KaimingInit(init.Initializer):
    """Base class. Initialize the array with HeKaiming init algorithm."""
    def __init__(self, a=0., mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingInit, self).__init__()
        self.mode = mode
        self.gain = _calculate_gain(nonlinearity, a)

    def _initialize(self, arr):
        raise NotImplementedError("Init algorithm not-implemented.")


class KaimingUniform(KaimingInit):
    """KaimingUniform init algorithm."""

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
        data = np.random.uniform(-bound, bound, arr.shape)

        _assignment(arr, data)


class KaimingNormal(KaimingInit):
    """KaimingNormal init algorithm."""

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        std = self.gain / math.sqrt(fan)
        data = np.random.normal(0, std, arr.shape)

        _assignment(arr, data)


def default_recurisive_init(custom_cell):
    ms.common.set_seed(0)
    for _, cell in custom_cell.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5.0)), cell.weight.data.shape,
                                                  cell.weight.data.dtype).to_tensor())
            if cell.bias is not None:
                fan_in, _ = _calculate_in_and_out(cell.weight.data.asnumpy())
                bound = 1 / math.sqrt(fan_in)
                cell.bias.set_data(Tensor(np.random.uniform(-bound, bound, cell.bias.data.shape), cell.bias.data.dtype))
        elif isinstance(cell, nn.Dense):
            cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), cell.weight.data.shape,
                                                  cell.weight.data.dtype).to_tensor())
            if cell.bias is not None:
                fan_in, _ = _calculate_in_and_out(cell.weight.data.asnumpy())
                bound = 1 / math.sqrt(fan_in)
                cell.bias.set_data(Tensor(np.random.uniform(-bound, bound, cell.bias.data.shape), cell.bias.data.dtype))
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass


def load_pretrain_model(ckpt_file, network, args):
    """load pretrain model."""
    if os.path.isfile(ckpt_file):
        param_dict = load_checkpoint(ckpt_file)
        param_dict_new = {}
        for k, v in param_dict.items():
            if k.startswith('moments.'):
                continue
            elif k.startswith('network.'):
                param_dict_new[k[8:]] = v
            else:
                param_dict_new[k] = v
        load_param_into_net(network, param_dict_new)
        args.logger.info("Load pretrained {:s} success".format(ckpt_file))
    else:
        args.logger.info("Do not load pretrained.")
@ -0,0 +1,216 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train scripts."""
import os
import time
import datetime

from mindspore import Tensor
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.common import set_seed

from src.dataset import create_dataset
from src.loss import CrossEntropy
from src.lr_generator import get_lr
from src.utils.logging import get_logger
from src.utils.optimizers_init import get_param_groups
from src.utils.var_init import load_pretrain_model
from src.image_classification import CSPDarknet53

from model_utils.config import config as default_config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num


set_seed(1)


class ProgressMonitor(Callback):
    """monitor loss and cost time."""
    def __init__(self, args):
        super(ProgressMonitor, self).__init__()
        self.me_epoch_start_time = 0
        self.me_epoch_start_step_num = 0
        self.args = args

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()

        cur_step_num = cb_params.cur_step_num - 1
        _epoch = cb_params.cur_epoch_num
        time_cost = time.time() - self.me_epoch_start_time
        fps_mean = self.args.per_batch_size * (cur_step_num - self.me_epoch_start_step_num) * \
                   self.args.group_size / time_cost
        per_step_time = 1000 * time_cost / (cur_step_num - self.me_epoch_start_step_num)
        self.args.logger.info('epoch[{}], iter[{}], loss: {}, mean_fps: {:.2f}'
                              ' imgs/sec, per_step_time: {:.2f} ms'.format(_epoch,
                                                                           cur_step_num % self.args.steps_per_epoch,
                                                                           cb_params.net_outputs,
                                                                           fps_mean,
                                                                           per_step_time))

        self.me_epoch_start_step_num = cur_step_num
        self.me_epoch_start_time = time.time()


def set_default_args(args):
    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
    args.image_size = list(map(int, args.image_size.split(',')))

    args.rank = get_rank_id()
    args.group_size = get_device_num()

    if args.is_dynamic_loss_scale == 1:
        args.loss_scale = 1

    args.rank_save_ckpt_flag = 0
    if args.is_save_on_master:
        if args.rank == 0:
            args.rank_save_ckpt_flag = 1
    else:
        args.rank_save_ckpt_flag = 1

    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime("%Y-%m-%d_time_%H_%M_%S"))
    args.logger = get_logger(args.outputs_dir, args.rank)
    return args


def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, default_config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not zip.")
        else:
            print("Zip has been extracted.")

    if default_config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(default_config.data_path, default_config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(default_config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains 8 devices as most.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))

    default_config.ckpt_path = os.path.join(default_config.output_path, default_config.ckpt_path)


@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
    config = set_default_args(default_config)
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
                        device_target=config.device_target, save_graphs=False, device_id=get_device_id())
    if config.is_distributed:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=config.group_size,
                                          gradients_mean=True)
        init()

    # dataloader
    de_dataset = create_dataset(config.data_dir, config.image_size, config.per_batch_size,
                                config.rank, config.group_size, num_parallel_workers=8)
    de_dataset.map_model = 4
    config.steps_per_epoch = de_dataset.get_dataset_size()

    config.logger.save_args(config)

    # network
    config.logger.important_info('start create network')
    network = CSPDarknet53(num_classes=config.num_classes)
    load_pretrain_model(config.pretrained, network, config)

    # lr
    lr = get_lr(config)

    # optimizer
    opt = Momentum(params=get_param_groups(network),
                   learning_rate=Tensor(lr),
                   momentum=config.momentum,
                   weight_decay=config.weight_decay,
                   loss_scale=config.loss_scale)

    # loss
    if not config.label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)

    if config.is_dynamic_loss_scale == 1:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
    else:
        loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
                  metrics={'acc'}, amp_level="O2")

    # checkpoint save
    progress_cb = ProgressMonitor(config)
    callbacks = [progress_cb,]
    if config.rank_save_ckpt_flag:
        ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_interval * config.steps_per_epoch,
                                       keep_checkpoint_max=config.ckpt_save_max)
        save_ckpt_path = os.path.join(config.outputs_dir, 'ckpt_' + str(config.rank) + '')
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(config.rank))
        callbacks.append(ckpt_cb)

    model.train(config.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)


if __name__ == '__main__':
    run_train()