From 55e1538ac62269ff98ec262c42a76dfef0509fd4 Mon Sep 17 00:00:00 2001 From: zhanghuiyao <1814619459@qq.com> Date: Mon, 7 Jun 2021 11:36:29 +0800 Subject: [PATCH] Add cspdarknet53 network to modelzoo --- model_zoo/official/cv/cspdarknet53/README.md | 305 ++++++++++++++++++ .../cv/cspdarknet53/default_config.yaml | 70 ++++ model_zoo/official/cv/cspdarknet53/eval.py | 156 +++++++++ model_zoo/official/cv/cspdarknet53/export.py | 38 +++ .../cv/cspdarknet53/mindspore_hub_conf.py | 21 ++ .../cv/cspdarknet53/model_utils/__init__.py | 0 .../cv/cspdarknet53/model_utils/config.py | 127 ++++++++ .../model_utils/device_adapter.py | 27 ++ .../cspdarknet53/model_utils/local_adapter.py | 36 +++ .../model_utils/moxing_adapter.py | 116 +++++++ .../scripts/run_distribute_train.sh | 66 ++++ .../cv/cspdarknet53/scripts/run_eval.sh | 45 +++ .../scripts/run_standalone_train.sh | 38 +++ .../official/cv/cspdarknet53/src/__init__.py | 0 .../cv/cspdarknet53/src/cspdarknet53.py | 232 +++++++++++++ .../official/cv/cspdarknet53/src/dataset.py | 119 +++++++ .../official/cv/cspdarknet53/src/head.py | 32 ++ .../cspdarknet53/src/image_classification.py | 57 ++++ .../official/cv/cspdarknet53/src/loss.py | 38 +++ .../cv/cspdarknet53/src/lr_generator.py | 92 ++++++ .../cv/cspdarknet53/src/utils/__init__.py | 0 .../src/utils/auto_mixed_precision.py | 71 ++++ .../cv/cspdarknet53/src/utils/custom_op.py | 27 ++ .../cv/cspdarknet53/src/utils/logging.py | 73 +++++ .../cspdarknet53/src/utils/optimizers_init.py | 32 ++ .../cv/cspdarknet53/src/utils/sampler.py | 52 +++ .../cv/cspdarknet53/src/utils/var_init.py | 155 +++++++++ model_zoo/official/cv/cspdarknet53/train.py | 216 +++++++++++++ 28 files changed, 2241 insertions(+) create mode 100644 model_zoo/official/cv/cspdarknet53/README.md create mode 100644 model_zoo/official/cv/cspdarknet53/default_config.yaml create mode 100644 model_zoo/official/cv/cspdarknet53/eval.py create mode 100644 model_zoo/official/cv/cspdarknet53/export.py create mode 100644 model_zoo/official/cv/cspdarknet53/mindspore_hub_conf.py create mode 100644 model_zoo/official/cv/cspdarknet53/model_utils/__init__.py create mode 100644 model_zoo/official/cv/cspdarknet53/model_utils/config.py create mode 100644 model_zoo/official/cv/cspdarknet53/model_utils/device_adapter.py create mode 100644 model_zoo/official/cv/cspdarknet53/model_utils/local_adapter.py create mode 100644 model_zoo/official/cv/cspdarknet53/model_utils/moxing_adapter.py create mode 100644 model_zoo/official/cv/cspdarknet53/scripts/run_distribute_train.sh create mode 100644 model_zoo/official/cv/cspdarknet53/scripts/run_eval.sh create mode 100644 model_zoo/official/cv/cspdarknet53/scripts/run_standalone_train.sh create mode 100644 model_zoo/official/cv/cspdarknet53/src/__init__.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/cspdarknet53.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/dataset.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/head.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/image_classification.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/loss.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/lr_generator.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/utils/__init__.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/utils/auto_mixed_precision.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/utils/custom_op.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/utils/logging.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/utils/optimizers_init.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/utils/sampler.py create mode 100644 model_zoo/official/cv/cspdarknet53/src/utils/var_init.py create mode 100644 model_zoo/official/cv/cspdarknet53/train.py diff --git a/model_zoo/official/cv/cspdarknet53/README.md b/model_zoo/official/cv/cspdarknet53/README.md new file mode 100644 index 00000000000..c17cab716a0 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/README.md @@ -0,0 +1,305 @@ +# Contents + +- [CSPDarkNet53 Description](#CSPDarkNet53-description) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Features](#features) + - [Mixed Precision](#mixed-precision) +- [Environment Requirements](#environment-requirements) +- [Quick Start](#quick-start) +- [Script Description](#script-description) + - [Script and Sample Code](#script-and-sample-code) + - [Training Process](#training-process) + - [Evaluation Process](#evaluation-process) + - [Evaluation](#evaluation) +- [Model Description](#model-description) + - [Performance](#performance) + - [Evaluation Performance](#evaluation-performance) + - [Inference Performance](#inference-performance) +- [Description of Random Situation](#description-of-random-situation) +- [ModelZoo Homepage](#modelzoo-homepage) + +# [CSPDarkNet53 Description](#contents) + +CSPDarkNet53 is a simple, highly modularized network architecture for image classification. It designs results in a homogeneous, multi-branch architecture that has only a few hyper-parameters to set in CSPDarkNet53. + +[Paper](https://arxiv.org/pdf/1911.11929.pdf) Chien-Yao Wang, Hong-Yuan Mark Liao, Yueh-Hua Wu, Ping-Yang Chen, Jun-Wei Hsieh, and I-Hau Yeh. CSPNet: A new backbone that can enhance learning capability of cnn. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshop (CVPR Workshop), 2020. 2, 7 + +# [Model architecture](#contents) + +The overall network architecture of CSPDarkNet53 is show below: + +[Link](https://arxiv.org/pdf/1911.11929.pdf) + +# [Dataset](#contents) + +Dataset used can refer to paper. + +- Dataset size: 125G, 1250k colorful images in 1000 classes + - Train: 120G, 1200k images + - Test: 5G, 50k images +- Data format: RGB images. + - Note: Data will be processed in src/dataset.py + +# [Features](#contents) + +## [Mixed Precision(Ascend)](#contents) + +The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. + +For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’. + +# [Environment Requirements](#contents) + +- Hardware(Ascend) +- Prepare hardware environment with Ascend processor. +- Framework + - [MindSpore](https://www.mindspore.cn/install/en) +- For more information, please check the resources below: + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +# [Quick Start](#contents) + +After installing MindSpore via the official website, you can start training and evaluation as follows: + +- Running local with Ascend + +```python +# run standalone training example with train.py +python train.py --is_distributed=0 --data_dir=$DATA_DIR > log.txt 2>&1 & + +# run distributed training example +bash run_standalone_train.sh [DEVICE_ID] [DATA_DIR] (option)[PATH_CHECKPOINT] + +# run distributed training example +bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR] (option)[PATH_CHECKPOINT] + +# run evaluation example with eval.py +python eval.py --is_distributed=0 --per_batch_size=1 --pretrained=$PATH_CHECKPOINT --data_dir=$DATA_DIR > log.txt 2>&1 & + +# run evaluation example +bash run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT] +``` + +For distributed training, a hccl configuration file with JSON format needs to be created in advance. +Please follow the instructions in the link below: + + +```bash +# Train ImageNet 8p on ModelArts +# (1) Perform a or b. +# a. Set "enable_modelarts=True" on default_config.yaml file. +# Set "is_distributed=1" on default_config.yaml file. +# Set "data_dir='/cache/data/ImageNet/train'" on default_config.yaml file. +# (option)Set "checkpoint_url='s3://dir_to_pretrained/'" on default_config.yaml file. +# (option)Set "pretrained='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file. +# (option)Set other parameters on default_config.yaml file you need. +# b. Add "enable_modelarts=True" on the website UI interface. +# Add "is_distributed=1" on the website UI interface. +# Add "data_dir='/cache/data/ImageNet/train'" on the website UI interface. +# (option)Add "checkpoint_url='s3://dir_to_pretrained/'" on the website UI interface. +# (option)Add "pretrained='/cache/checkpoint_path/model.ckpt'" on the website UI interface. +# (option)Add other parameters on the website UI interface. +# (2) (option)Upload or copy your pretrained model to S3 bucket if pretrained is set. +# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.) +# (4) Set the code directory to "/path/cspdarknet53" on the website UI interface. +# (5) Set the startup file to "train.py" on the website UI interface. +# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface. +# (7) Create your job. +# +# Eval ImageNet 1p on ModelArts +# (1) Perform a or b. +# a. Set "enable_modelarts=True" on default_config.yaml file. +# Set "is_distributed=0" on default_config.yaml file. +# Set "per_batch_size=1" on default_config.yaml file. +# Set "data_dir='/cache/data/ImageNet/validation_preprocess'" on default_config.yaml file. +# Set "checkpoint_url='s3://dir_to_pretrained/'" on default_config.yaml file. +# Set "pretrained='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file. +# (option)Set other parameters on default_config.yaml file you need. +# b. Add "enable_modelarts=True" on the website UI interface. +# Add "is_distributed=1" on the website UI interface. +# Add "per_batch_size=1" on the website UI interface. +# Add "data_dir='/cache/data/ImageNet/validation_preprocess'" on the website UI interface. +# Add "checkpoint_url='s3://dir_to_pretrained/'" on the website UI interface. +# Add "pretrained='/cache/checkpoint_path/model.ckpt'" on the website UI interface. +# (option)Add other parameters on the website UI interface. +# (2) Upload or copy your trained model to S3 bucket. +# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.) +# (4) Set the code directory to "/path/cspdarknet53" on the website UI interface. +# (5) Set the startup file to "eval.py" on the website UI interface. +# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface. +# (7) Create your job. +``` + +# [Script description](#contents) + +## [Script and sample code](#contents) + +```shell +. +└─cspdarknet53 + ├─README.md + ├── model_utils + ├─__init__.py # init file + ├─config.py # Parse arguments + ├─device_adapter.py # Device adapter for ModelArts + ├─local_adapter.py # Local adapter + ├─moxing_adapter.py # Moxing adapter for ModelArts + ├─scripts + ├─run_standalone_train.sh # launch standalone training with ascend platform(1p) + ├─run_distribute_train.sh # launch distributed training with ascend platform(8p) + └─run_eval.sh # launch evaluating with ascend platform + ├─src + ├─utils + ├─__init__.py # modeule init file + ├─auto_mixed_precision.py # Auto mixed precision + ├─custom_op.py # network operations + ├─logging.py # Custom logger + ├─optimizers_init.py # optimizer parameters + ├─sampler.py # choose samples from the dataset + ├─var_init.py # Initialize + ├─__init__.py # parameter configuration + ├─cspdarknet53.py # network definition + ├─dataset.py # data preprocessing + ├─head.py # common head architecture + ├─image_classification.py # Image classification + ├─loss.py # Customized CrossEntropy loss function + ├─lr_generator.py # learning rate generator + ├─mindspore_hub_conf.py # mindspore_hub_conf script + ├─default_config.yaml # Configurations + ├─eval.py # eval net + ├─export.py # convert checkpoint + └─train.py # train net +``` + +## [Script Parameters](#contents) + +```python +Major parameters in default_config.yaml are: +'data_dir' # dataset dir +'pretrained' # checkpoint dir +'is_distributed' # is distribute param +'per_batch_size' # batch size each device +'log_path' # save log file path +``` + +## [Training process](#contents) + +### Usage + +You can start training using python or shell scripts. The usage of shell scripts as follows: + +- Ascend: + +```shell +# distribute training(8p) +bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR] (option)[PATH_CHECKPOINT] +# standalone training +bash run_standalone_train.sh [DEVICE_ID] [DATA_DIR] (option)[PATH_CHECKPOINT] +``` + +> Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV3, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size. +> +> This is processor cores binding operation regarding the `device_num` and total processor numbers. If you are not expect to do it, remove the operations `taskset` in `scripts/run_distribute_train.sh` + +### Launch + +```python +# training example + python: + python train.py --is_distributed=0 --pretrained=PATH_CHECKPOINT --data_dir=DATA_DIR > log.txt 2>&1 & + + shell: + # distribute training example(8p) + bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR] (option)[PATH_CHECKPOINT] + # standalone training example + bash run_standalone_train.sh [DEVICE_ID] [DATA_DIR] (option)[PATH_CHECKPOINT] +``` + +### Result + +Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./log.txt`. + +## [Evaluation Process](#contents) + +### Usage + +You can start training using python or shell scripts. The usage of shell scripts as follows: + +- Ascend: + +```shell + bash run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT] +``` + +### Launch + +```python +# eval example + python: + python eval.py --is_distributed=0 --per_batch_size=1 --pretrained=PATH_CHECKPOINT --data_dir=DATA_DIR > log.txt 2>&1 & + + shell: + bash run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT] +``` + +> checkpoint can be produced in training process. + +### Result + +Evaluation result will be stored in the example path, you can find result in `eval.log`. + +## Model Export + +```shell +python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format[EXPORT_FORMAT] +``` + +`EXPORT_FORMAT` should be in ["AIR", "MINDIR"] + +# [Model description](#contents) + +## [Performance](#contents) + +### Evaluation Performance + +| Parameters | Ascend | +| -------------------------- | ---------------------------------------------- | +| Model Version | CSPDarkNet53 | +| Resource | Ascend 910; cpu 2.60GHz, 192cores; memory 755G; OS Euler2.8 | +| uploaded Date | 06/02/2021 | +| MindSpore Version | 1.2.0 | +| Dataset | 1200k images | +| Batch_size | 64 | +| Training Parameters | default_config.yaml | +| Optimizer | Momentum | +| Loss Function | CrossEntropy | +| Outputs | probability | +| Loss | 1.78 | +| Total time (8p) | 8ps: 14h | +| Checkpoint for Fine tuning | 217M (.ckpt file) | +| Speed | 8pc: 3977 imgs/sec | +| Scripts | [cspdarknet53 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cspdarknet53) | + +### Inference Performance + +| Parameters | Ascend | +| ------------------- | --------------------------- | +| Model Version | CSPDarkNet53 | +| Resource | Ascend 910; cpu 2.60GHz, 192cores; memory 755G; OS Euler2.8 | +| Uploaded Date | 06/02/2021 | +| MindSpore Version | 1.2.0 | +| Dataset | 50k images | +| Batch_size | 1 | +| Outputs | probability | +| Accuracy | acc=78.48%(TOP1) | +| | acc=94.21%(TOP5) | + +# [Description of Random Situation](#contents) + +We use random seed in "train.py", "./src/utils/var_init.py", "./src/utils/sampler.py". + +# [ModelZoo Homepage](#contents) + +Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/cv/cspdarknet53/default_config.yaml b/model_zoo/official/cv/cspdarknet53/default_config.yaml new file mode 100644 index 00000000000..20c9ed0e7f3 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/default_config.yaml @@ -0,0 +1,70 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path" +device_target: "Ascend" +need_modelarts_dataset_unzip: True +modelarts_dataset_unzip_name: "ImageNet" + +# ============================================================================== +# default options +image_size: "224,224" +num_classes: 1000 +lr: 0.1 +lr_scheduler: "cosine_annealing" +lr_epochs: "30,60,90,120" +lr_gamma: 0.1 +eta_min: 0 +T_max: 150 +max_epoch: 150 +warmup_epochs: 5 +weight_decay: 0.0001 +momentum: 0.9 +is_dynamic_loss_scale: 0 +loss_scale: 1024 +label_smooth: 1 +label_smooth_factor: 0.1 +ckpt_interval: 1 +ckpt_save_max: 10 +ckpt_path: "outputs/" +is_save_on_master: 1 + +data_dir: "" +pretrained: "" +is_distributed: 1 +per_batch_size: 64 + +log_path: "outputs/" + +# export options +export_batch_size: 1 +ckpt_file: "" +file_name: "cspdarknet53" +file_format: "AIR" +width: 224 +height: 224 + +--- + +# Help description for each configuration +enable_modelarts: "Whether training on modelarts, default: False" +data_url: "Url for modelarts" +train_url: "Url for modelarts" +data_path: "The location of the input data." +output_path: "The location of the output file." +device_target: 'Target device type' +graph_ckpt: "graph ckpt or feed ckpt" + +# export options +export_batch_size: "batch size for export" +ckpt_file: "cspdarknet53 ckpt file" +file_name: "output air name." +file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']" +width: "input width" +height: "input height" \ No newline at end of file diff --git a/model_zoo/official/cv/cspdarknet53/eval.py b/model_zoo/official/cv/cspdarknet53/eval.py new file mode 100644 index 00000000000..d37b62ea057 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/eval.py @@ -0,0 +1,156 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""evaluate imagenet.""" +import os +import time +import datetime +import numpy as np + +from mindspore import Tensor, context +from mindspore.common import dtype as mstype + +from src.utils.logging import get_logger +from src.utils.auto_mixed_precision import auto_mixed_precision +from src.utils.var_init import load_pretrain_model +from src.image_classification import CSPDarknet53 +from src.dataset import create_dataset + +from model_utils.config import config +from model_utils.moxing_adapter import moxing_wrapper +from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num + +def get_top5_acc(top5_arg, gt_class): + sub_count = 0 + for top5, gt in zip(top5_arg, gt_class): + if gt in top5: + sub_count += 1 + return sub_count + + +def modelarts_pre_process(): + '''modelarts pre process function.''' + def unzip(zip_file, save_dir): + import zipfile + s_time = time.time() + if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)): + zip_isexist = zipfile.is_zipfile(zip_file) + if zip_isexist: + fz = zipfile.ZipFile(zip_file, 'r') + data_num = len(fz.namelist()) + print("Extract Start...") + print("unzip file num: {}".format(data_num)) + data_print = int(data_num / 100) if data_num > 100 else 1 + i = 0 + for file in fz.namelist(): + if i % data_print == 0: + print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True) + i += 1 + fz.extract(file, save_dir) + print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60), + int(int(time.time() - s_time) % 60))) + print("Extract Done.") + else: + print("This is not zip.") + else: + print("Zip has been extracted.") + + if config.need_modelarts_dataset_unzip: + zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip") + save_dir_1 = os.path.join(config.data_path) + + sync_lock = "/tmp/unzip_sync.lock" + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("Zip file path: ", zip_file_1) + print("Unzip file save dir: ", save_dir_1) + unzip(zip_file_1, save_dir_1) + print("===Finish extract data synchronization===") + try: + os.mknod(sync_lock) + except IOError: + pass + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) + + config.log_path = os.path.join(config.output_path, config.log_path) + + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_eval(): + '''Eval.''' + config.image_size = list(map(int, config.image_size.split(','))) + config.rank = get_rank_id() + config.group_size = get_device_num() + if config.is_distributed or config.group_size > 1: + raise ValueError("Not support distribute eval.") + config.outputs_dir = os.path.join(config.log_path, + datetime.datetime.now().strftime("%Y-%m-%d_time_%H_%M_%S")) + config.logger = get_logger(config.outputs_dir, config.rank) + + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=config.device_target, save_graphs=False, device_id=get_device_id()) + config.logger.save_args(config) + + # network + config.logger.important_info('start create network') + de_dataset = create_dataset(config.data_dir, config.image_size, config.per_batch_size, + config.rank, config.group_size, mode="eval") + eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True, num_epochs=1) + network = CSPDarknet53(num_classes=config.num_classes) + load_pretrain_model(config.pretrained, network, config) + + img_tot = 0 + top1_correct = 0 + top5_correct = 0 + if config.device_target == "Ascend": + network.to_float(mstype.float16) + elif config.device_target == "GPU": + auto_mixed_precision(network) + else: + raise ValueError("Not support device type: {}".format(config.device_target)) + network.set_train(False) + t_start = time.time() + for data, gt_classes in eval_dataloader: + out = network(Tensor(data, mstype.float32)) + out = out.asnumpy() + + top1_output = np.argmax(out, (-1)) + top5_output = np.argsort(out)[:, -5:] + + t1_correct = np.equal(top1_output, gt_classes).sum() + top1_correct += t1_correct + top5_correct += get_top5_acc(top5_output, gt_classes) + img_tot += config.per_batch_size + + t_end = time.time() + if config.rank == 0: + time_cost = t_end - t_start + fps = (img_tot - config.per_batch_size) * config.group_size / time_cost + config.logger.info('Inference Performance: {:.2f} img/sec'.format(fps)) + top1_acc = 100.0 * top1_correct / img_tot + top5_acc = 100.0 * top5_correct / img_tot + config.logger.info("top1_correct={}, tot={}, acc={:.2f}%(TOP1)".format(top1_correct, img_tot, top1_acc)) + config.logger.info("top5_correct={}, tot={}, acc={:.2f}%(TOP5)".format(top5_correct, img_tot, top5_acc)) + + + +if __name__ == '__main__': + run_eval() diff --git a/model_zoo/official/cv/cspdarknet53/export.py b/model_zoo/official/cv/cspdarknet53/export.py new file mode 100644 index 00000000000..76844d398fd --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/export.py @@ -0,0 +1,38 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""export checkpoint file into air, onnx, mindir models""" +import numpy as np +import mindspore as ms +from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context + +from src.image_classification import CSPDarknet53 + +from model_utils.config import config +from model_utils.device_adapter import get_device_id + +if __name__ == '__main__': + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + if config.device_target == "Ascend": + context.set_context(device_id=get_device_id()) + + net = CSPDarknet53(num_classes=config.num_classes) + param_dict = load_checkpoint(config.ckpt_file) + load_param_into_net(net, param_dict) + net.set_train(False) + + input_shape = [config.export_batch_size, 3, config.width, config.height] + input_arr = Tensor(np.random.uniform(0.0, 1.0, size=input_shape), ms.float32) + + export(net, input_arr, file_name=config.file_name, file_format=config.file_format) diff --git a/model_zoo/official/cv/cspdarknet53/mindspore_hub_conf.py b/model_zoo/official/cv/cspdarknet53/mindspore_hub_conf.py new file mode 100644 index 00000000000..4b110422822 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/mindspore_hub_conf.py @@ -0,0 +1,21 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config.""" +from src.cspdarknet53 import CspDarkNet53 + +def create_network(name, *args, **kwargs): + if name == 'cspdarknet53': + return CspDarkNet53(*args, **kwargs) + raise NotImplementedError(f"{name} is not implemented in the repo, please try 'cspdarknet53'.") diff --git a/model_zoo/official/cv/cspdarknet53/model_utils/__init__.py b/model_zoo/official/cv/cspdarknet53/model_utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/official/cv/cspdarknet53/model_utils/config.py b/model_zoo/official/cv/cspdarknet53/model_utils/config.py new file mode 100644 index 00000000000..2c191e9f748 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/model_utils/config.py @@ -0,0 +1,127 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Parse arguments""" + +import os +import ast +import argparse +from pprint import pprint, pformat +import yaml + +class Config: + """ + Configuration namespace. Convert dictionary to members. + """ + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): + """ + Parse command line arguments to the configuration according to the default yaml. + + Args: + parser: Parent parser. + cfg: Base configuration. + helper: Helper description. + cfg_path: Path to the default yaml config. + """ + parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", + parents=[parser]) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path) + choice = choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice, + help=help_description) + else: + parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice, + help=help_description) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file. + + Args: + yaml_path: Path to the yaml config. + """ + with open(yaml_path, 'r') as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + cfg_choices = {} + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + cfg_choices = {} + elif len(cfgs) == 3: + cfg, cfg_helper, cfg_choices = cfgs + else: + raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml") + print(cfg_helper) + except: + raise ValueError("Failed to parse yaml") + return cfg, cfg_helper, cfg_choices + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments. + + Args: + args: Command line arguments. + cfg: Base configuration. + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments. + """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"), + help="Config file path") + path_args, _ = parser.parse_known_args() + default, helper, choices = parse_yaml(path_args.config_path) + pprint(default) + args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) + final_config = merge(args, default) + return Config(final_config) + +config = get_config() diff --git a/model_zoo/official/cv/cspdarknet53/model_utils/device_adapter.py b/model_zoo/official/cv/cspdarknet53/model_utils/device_adapter.py new file mode 100644 index 00000000000..7c5d7f837dd --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/model_utils/device_adapter.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Device adapter for ModelArts""" + +from .config import config + +if config.enable_modelarts: + from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + "get_device_id", "get_device_num", "get_rank_id", "get_job_id" +] diff --git a/model_zoo/official/cv/cspdarknet53/model_utils/local_adapter.py b/model_zoo/official/cv/cspdarknet53/model_utils/local_adapter.py new file mode 100644 index 00000000000..769fa6dc78e --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/model_utils/local_adapter.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Local adapter""" + +import os + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + return "Local Job" diff --git a/model_zoo/official/cv/cspdarknet53/model_utils/moxing_adapter.py b/model_zoo/official/cv/cspdarknet53/model_utils/moxing_adapter.py new file mode 100644 index 00000000000..25838a7da99 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/model_utils/moxing_adapter.py @@ -0,0 +1,116 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Moxing adapter for ModelArts""" + +import os +import functools +from mindspore import context +from .config import config + +_global_sync_count = 0 + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + job_id = os.getenv('JOB_ID') + job_id = job_id if job_id != "" else "default" + return job_id + +def sync_data(from_path, to_path): + """ + Download data from remote obs to local directory if the first url is remote url and the second one is local path + Upload data from local directory to remote obs in contrast. + """ + import moxing as mox + import time + global _global_sync_count + sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count) + _global_sync_count += 1 + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("from path: ", from_path) + print("to path: ", to_path) + mox.file.copy_parallel(from_path, to_path) + print("===finish data synchronization===") + try: + os.mknod(sync_lock) + except IOError: + pass + print("===save flag===") + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print("Finish sync data from {} to {}.".format(from_path, to_path)) + + +def moxing_wrapper(pre_process=None, post_process=None): + """ + Moxing wrapper to download dataset and upload outputs. + """ + def wrapper(run_func): + @functools.wraps(run_func) + def wrapped_func(*args, **kwargs): + # Download data from data_url + if config.enable_modelarts: + if config.data_url: + sync_data(config.data_url, config.data_path) + print("Dataset downloaded: ", os.listdir(config.data_path)) + if config.checkpoint_url: + sync_data(config.checkpoint_url, config.load_path) + print("Preload downloaded: ", os.listdir(config.load_path)) + if config.train_url: + sync_data(config.train_url, config.output_path) + print("Workspace downloaded: ", os.listdir(config.output_path)) + + context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) + config.device_num = get_device_num() + config.device_id = get_device_id() + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + if pre_process: + pre_process() + + # Run the main function + run_func(*args, **kwargs) + + # Upload data to train_url + if config.enable_modelarts: + if post_process: + post_process() + + if config.train_url: + print("Start to copy output directory") + sync_data(config.output_path, config.train_url) + return wrapped_func + return wrapper diff --git a/model_zoo/official/cv/cspdarknet53/scripts/run_distribute_train.sh b/model_zoo/official/cv/cspdarknet53/scripts/run_distribute_train.sh new file mode 100644 index 00000000000..84399d3fdbf --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/scripts/run_distribute_train.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] && [ $# != 3 ] +then + echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR] (option)[PATH_CHECKPOINT]" +exit 1 +fi + +DATA_DIR=$2 +export RANK_TABLE_FILE=$1 +export RANK_SIZE=8 +export HCCL_CONNECT_TIMEOUT=600 +echo "hccl connect timeout has changed to 600 scecond" + +PATH_CHECKPOINT="" +if [ $# == 3 ] +then + PATH_CHECKPOINT=$3 +fi + +cores=`cat /proc/cpuinfo|grep "processor" |wc -l` +echo "the number of logical core" $cores +avg_core_per_rank=`expr $cores \/ $RANK_SIZE` +core_gap=`expr $avg_core_per_rank \- 1` +echo "avg_core_per_rank" $avg_core_per_rank +echo "core_gap" $core_gap +for((i=0;i env.log + taskset -c $cmdopt python train.py \ + --is_distributed=1 \ + --pretrained=$PATH_CHECKPOINT \ + --data_dir=$DATA_DIR > log.txt 2>&1 & + cd ../ +done diff --git a/model_zoo/official/cv/cspdarknet53/scripts/run_eval.sh b/model_zoo/official/cv/cspdarknet53/scripts/run_eval.sh new file mode 100644 index 00000000000..769392f913f --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/scripts/run_eval.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] +then + echo "Usage: bash run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT]" +exit 1 +fi + +export RANK_SIZE=1 +export RANK_ID=0 +export DEVICE_ID=$1 + +DATA_DIR=$2 +PATH_CHECKPOINT=$3 + +rm -rf ./eval$1 +mkdir ./eval$1 +cp ../*.py ./eval$1 +cp ../*.yaml ./eval$1 +cp -r ../src ./eval$1 +cp -r ../model_utils ./eval$1 +cd ./eval$1 || exit + +echo "start training for rank $RANK_ID, device $DEVICE_ID" +env > env.log + +python eval.py \ + --is_distributed=0 \ + --per_batch_size=1 \ + --pretrained=$PATH_CHECKPOINT \ + --data_dir=$DATA_DIR > log.txt 2>&1 & diff --git a/model_zoo/official/cv/cspdarknet53/scripts/run_standalone_train.sh b/model_zoo/official/cv/cspdarknet53/scripts/run_standalone_train.sh new file mode 100644 index 00000000000..76e65005ea9 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/scripts/run_standalone_train.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] && [ $# != 3 ] +then + echo "Usage: bash run_standalone_train.sh [DEVICE_ID] [DATA_DIR] (option)[PATH_CHECKPOINT]" +exit 1 +fi + +export RANK_SIZE=1 +export RANK_ID=0 +export DEVICE_ID=$1 + +DATA_DIR=$2 +PATH_CHECKPOINT="" +if [ $# == 3 ] +then + PATH_CHECKPOINT=$3 +fi + +python train.py \ + --is_distributed=0 \ + --pretrained=$PATH_CHECKPOINT \ + --data_dir=$DATA_DIR > log.txt 2>&1 & + diff --git a/model_zoo/official/cv/cspdarknet53/src/__init__.py b/model_zoo/official/cv/cspdarknet53/src/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/official/cv/cspdarknet53/src/cspdarknet53.py b/model_zoo/official/cv/cspdarknet53/src/cspdarknet53.py new file mode 100644 index 00000000000..a26c9ff365f --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/cspdarknet53.py @@ -0,0 +1,232 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CSPDarkNet53 model.""" +import mindspore.nn as nn +from mindspore.ops import operations as P + + +class Mish(nn.Cell): + """Mish activation method""" + def __init__(self): + super(Mish, self).__init__() + self.mul = P.Mul() + self.tanh = P.Tanh() + self.softplus = P.Softplus() + + def construct(self, input_x): + res1 = self.softplus(input_x) + tanh = self.tanh(res1) + output = self.mul(input_x, tanh) + + return output + +def conv_block(in_channels, + out_channels, + kernel_size, + stride, + dilation=1): + """Get a conv2d batchnorm and relu layer""" + pad_mode = 'same' + padding = 0 + + return nn.SequentialCell( + [nn.Conv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + pad_mode=pad_mode), + nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5), + Mish() + ] + ) + + +class ResidualBlock(nn.Cell): + """ + DarkNet V1 residual block definition. + + Args: + in_channels: Integer. Input channel. + out_channels: Integer. Output channel. + + Returns: + Tensor, output tensor. + Examples: + ResidualBlock(3, 208) + """ + def __init__(self, + in_channels, + out_channels): + + super(ResidualBlock, self).__init__() + out_chls = out_channels + self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1) + self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1) + self.add = P.Add() + + def construct(self, x): + identity = x + out = self.conv1(x) + out = self.conv2(out) + out = self.add(out, identity) + + return out + +class CspDarkNet53(nn.Cell): + """ + DarkNet V1 network. + + Args: + block: Cell. Block for network. + layer_nums: List. Numbers of different layers. + in_channels: Integer. Input channel. + out_channels: Integer. Output channel. + num_classes: Integer. Class number. Default:100. + + Returns: + Tuple, tuple of output tensor,(f1,f2,f3,f4,f5). + + Examples: + DarkNet(ResidualBlock) + """ + def __init__(self, + block, + detect=False): + super(CspDarkNet53, self).__init__() + + self.outchannel = 1024 + self.detect = detect + self.concat = P.Concat(axis=1) + self.add = P.Add() + + self.conv0 = conv_block(3, 32, kernel_size=3, stride=1) + self.conv1 = conv_block(32, 64, kernel_size=3, stride=2) + self.conv2 = conv_block(64, 64, kernel_size=1, stride=1) + self.conv3 = conv_block(64, 32, kernel_size=1, stride=1) + self.conv4 = conv_block(32, 64, kernel_size=3, stride=1) + self.conv5 = conv_block(64, 64, kernel_size=1, stride=1) + self.conv6 = conv_block(64, 64, kernel_size=1, stride=1) + self.conv7 = conv_block(128, 64, kernel_size=1, stride=1) + self.conv8 = conv_block(64, 128, kernel_size=3, stride=2) + self.conv9 = conv_block(128, 64, kernel_size=1, stride=1) + self.conv10 = conv_block(64, 64, kernel_size=1, stride=1) + self.conv11 = conv_block(128, 64, kernel_size=1, stride=1) + self.conv12 = conv_block(128, 128, kernel_size=1, stride=1) + self.conv13 = conv_block(128, 256, kernel_size=3, stride=2) + self.conv14 = conv_block(256, 128, kernel_size=1, stride=1) + self.conv15 = conv_block(128, 128, kernel_size=1, stride=1) + self.conv16 = conv_block(256, 128, kernel_size=1, stride=1) + self.conv17 = conv_block(256, 256, kernel_size=1, stride=1) + self.conv18 = conv_block(256, 512, kernel_size=3, stride=2) + self.conv19 = conv_block(512, 256, kernel_size=1, stride=1) + self.conv20 = conv_block(256, 256, kernel_size=1, stride=1) + self.conv21 = conv_block(512, 256, kernel_size=1, stride=1) + self.conv22 = conv_block(512, 512, kernel_size=1, stride=1) + self.conv23 = conv_block(512, 1024, kernel_size=3, stride=2) + self.conv24 = conv_block(1024, 512, kernel_size=1, stride=1) + self.conv25 = conv_block(512, 512, kernel_size=1, stride=1) + self.conv26 = conv_block(1024, 512, kernel_size=1, stride=1) + self.conv27 = conv_block(1024, 1024, kernel_size=1, stride=1) + + self.layer2 = self._make_layer(block, 2, in_channel=64, out_channel=64) + self.layer3 = self._make_layer(block, 8, in_channel=128, out_channel=128) + self.layer4 = self._make_layer(block, 8, in_channel=256, out_channel=256) + self.layer5 = self._make_layer(block, 4, in_channel=512, out_channel=512) + + def _make_layer(self, block, layer_num, in_channel, out_channel): + """ + Make Layer for DarkNet. + + :param block: Cell. DarkNet block. + :param layer_num: Integer. Layer number. + :param in_channel: Integer. Input channel. + :param out_channel: Integer. Output channel. + :return: SequentialCell, the output layer. + + Examples: + _make_layer(ConvBlock, 1, 128, 256) + """ + layers = [] + darkblk = block(in_channel, out_channel) + layers.append(darkblk) + + for _ in range(1, layer_num): + darkblk = block(out_channel, out_channel) + layers.append(darkblk) + + return nn.SequentialCell(layers) + + def construct(self, x): + """construct method""" + c1 = self.conv0(x) + c2 = self.conv1(c1) #route + c3 = self.conv2(c2) + c4 = self.conv3(c3) + c5 = self.conv4(c4) + c6 = self.add(c3, c5) + c7 = self.conv5(c6) + c8 = self.conv6(c2) + c9 = self.concat((c7, c8)) + c10 = self.conv7(c9) + c11 = self.conv8(c10) #route + c12 = self.conv9(c11) + c13 = self.layer2(c12) + c14 = self.conv10(c13) + c15 = self.conv11(c11) + c16 = self.concat((c14, c15)) + c17 = self.conv12(c16) + c18 = self.conv13(c17) #route + c19 = self.conv14(c18) + c20 = self.layer3(c19) + c21 = self.conv15(c20) + c22 = self.conv16(c18) + c23 = self.concat((c21, c22)) + c24 = self.conv17(c23) #output1 + c25 = self.conv18(c24) #route + c26 = self.conv19(c25) + c27 = self.layer4(c26) + c28 = self.conv20(c27) + c29 = self.conv21(c25) + c30 = self.concat((c28, c29)) + c31 = self.conv22(c30) #output2 + c32 = self.conv23(c31) #route + c33 = self.conv24(c32) + c34 = self.layer5(c33) + c35 = self.conv25(c34) + c36 = self.conv26(c32) + c37 = self.concat((c35, c36)) + c38 = self.conv27(c37) #output3 + + if self.detect: + return c24, c31, c38 + + return c38 + + def get_out_channels(self): + return self.outchannel + +def cspdarknet53(): + """ + Get CSPDarkNet53 neural network. + + Returns: + Cell, cell instance of CSPDarkNet53 neural network. + + Examples: + cspdarknet53() + """ + return CspDarkNet53(ResidualBlock) diff --git a/model_zoo/official/cv/cspdarknet53/src/dataset.py b/model_zoo/official/cv/cspdarknet53/src/dataset.py new file mode 100644 index 00000000000..762f0a33205 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/dataset.py @@ -0,0 +1,119 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in train.py and eval.py +""" +import os + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.vision.c_transforms as V_C +from PIL import Image, ImageFile +from .utils.sampler import DistributedSampler + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +class TxtDataset(): + """ + create txt dataset. + + Args: + Returns: + de_dataset. + """ + + def __init__(self, root, txt_name): + super(TxtDataset, self).__init__() + self.imgs = [] + self.labels = [] + fin = open(txt_name, 'r') + for line in fin: + image_name, label = line.strip().split(' ') + self.imgs.append(os.path.join(root, image_name)) + self.labels.append(int(label)) + fin.close() + + def __getitem__(self, item): + img = Image.open(self.imgs[item]).convert('RGB') + return img, self.labels[item] + + def __len__(self): + return len(self.imgs) + + +def create_dataset(data_dir, image_size, per_batch_size, rank, group_size, + mode="train", + input_mode="folder", + root='', + num_parallel_workers=None, + shuffle=None, + sampler=None, + class_indexing=None, + drop_remainder=True, + transform=None, + target_transform=None): + "create ImageNet dataset." + + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + if transform is None: + if mode == "train": + transform_img = [ + V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + V_C.RandomHorizontalFlip(prob=0.5), + V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4), + V_C.Normalize(mean=mean, std=std), + V_C.HWC2CHW() + ] + else: + transform_img = [ + V_C.Decode(), + V_C.Resize((256, 256)), + V_C.CenterCrop(image_size), + V_C.Normalize(mean=mean, std=std), + V_C.HWC2CHW() + ] + else: + transform_img = transform + + if target_transform is None: + transform_label = [C.TypeCast(mstype.int32)] + else: + transform_label = target_transform + + + if input_mode == 'folder': + de_dataset = ds.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers, + shuffle=shuffle, sampler=sampler, class_indexing=class_indexing, + num_shards=group_size, shard_id=rank) + else: + dataset = TxtDataset(root, data_dir) + sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) + de_dataset = ds.GeneratorDataset(dataset, ['image', 'label'], sampler=sampler) + + de_dataset = de_dataset.map(operations=transform_img, input_columns="image", + num_parallel_workers=num_parallel_workers) + de_dataset = de_dataset.map(operations=transform_label, input_columns="label", + num_parallel_workers=num_parallel_workers) + + columns_to_project = ['image', 'label'] + de_dataset = de_dataset.project(columns=columns_to_project) + + de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder) + de_dataset = de_dataset.repeat(1) + + return de_dataset diff --git a/model_zoo/official/cv/cspdarknet53/src/head.py b/model_zoo/official/cv/cspdarknet53/src/head.py new file mode 100644 index 00000000000..c725b27f3c7 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/head.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""common head architecture.""" + +import mindspore.nn as nn +from src.utils.custom_op import GlobalAvgPooling + +__all__ = ["CommonHead"] + +class CommonHead(nn.Cell): + """Common Head.""" + def __init__(self, num_classes, out_channels): + super(CommonHead, self).__init__() + self.avgpool = GlobalAvgPooling() + self.fc = nn.Dense(out_channels, num_classes, has_bias=True).add_flags_recursive(fp16=True) + + def construct(self, x): + x = self.avgpool(x) + x = self.fc(x) + return x diff --git a/model_zoo/official/cv/cspdarknet53/src/image_classification.py b/model_zoo/official/cv/cspdarknet53/src/image_classification.py new file mode 100644 index 00000000000..b62a17514d3 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/image_classification.py @@ -0,0 +1,57 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Image classification.""" +import mindspore.nn as nn +from mindspore.ops import operations as P + +from src.cspdarknet53 import cspdarknet53 +from src.head import CommonHead +from src.utils.var_init import default_recurisive_init + +class ImageClassificationNetwork(nn.Cell): + """Architecture of Image Classification Network.""" + def __init__(self, backbone, head, include_top=True, activation="None"): + super(ImageClassificationNetwork, self).__init__() + self.backbone = backbone + self.include_top = include_top + self.need_activation = False + if self.include_top: + self.head = head + if activation != "None": + self.need_activation = True + if activation == "Sigmoid": + self.activation = P.Sigmoid() + elif activation == "Softmax": + self.activation = P.Softmax() + else: + raise NotImplementedError("The activation {} not in ['Sigmoid', 'Softmax'].".format(activation)) + + def construct(self, x): + x = self.backbone(x) + if self.include_top: + x = self.head(x) + if self.need_activation: + x = self.activation(x) + return x + +class CSPDarknet53(ImageClassificationNetwork): + """CSPDarknet53 architecture.""" + def __init__(self, num_classes=1000, include_top=True, activation="None"): + backbone = cspdarknet53() + out_channels = backbone.get_out_channels() + head = CommonHead(num_classes=num_classes, out_channels=out_channels) + super(CSPDarknet53, self).__init__(backbone, head, include_top, activation) + + default_recurisive_init(self) diff --git a/model_zoo/official/cv/cspdarknet53/src/loss.py b/model_zoo/official/cv/cspdarknet53/src/loss.py new file mode 100644 index 00000000000..b2219b220b5 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/loss.py @@ -0,0 +1,38 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define loss function for network.""" +from mindspore.nn.loss.loss import Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor +from mindspore.common import dtype as mstype +import mindspore.nn as nn + + +class CrossEntropy(Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits""" + def __init__(self, smooth_factor=0., num_classes=1000): + super(CrossEntropy, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logit, label): + one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss = self.ce(logit, one_hot_label) + loss = self.mean(loss, 0) + return loss diff --git a/model_zoo/official/cv/cspdarknet53/src/lr_generator.py b/model_zoo/official/cv/cspdarknet53/src/lr_generator.py new file mode 100644 index 00000000000..acca176e9db --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/lr_generator.py @@ -0,0 +1,92 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate generator""" +import math +from collections import Counter +import numpy as np + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + """Applies liner decay to generate learning rate array.""" + + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """Applies cosine decay to generate learning rate array.""" + + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """Applies three steps decay to generate learning rate array.""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma**milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + +def get_lr(args): + """generate learning rate array""" + + if args.lr_scheduler == "exponential": + lr = warmup_step_lr(args.lr, + args.lr_epochs, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + gamma=args.lr_gamma) + + elif args.lr_scheduler == "cosine_annealing": + lr = warmup_cosine_annealing_lr(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.T_max, + args.eta_min) + else: + raise NotImplementedError(args.lr_scheduler) + + return lr diff --git a/model_zoo/official/cv/cspdarknet53/src/utils/__init__.py b/model_zoo/official/cv/cspdarknet53/src/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/official/cv/cspdarknet53/src/utils/auto_mixed_precision.py b/model_zoo/official/cv/cspdarknet53/src/utils/auto_mixed_precision.py new file mode 100644 index 00000000000..a77f8a34a0b --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/utils/auto_mixed_precision.py @@ -0,0 +1,71 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""auto mixed precision""" +from collections.abc import Iterable +import mindspore.nn as nn +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + + +def check_type_name(arg_name, arg_type, valid_types, prim_name): + """Checks whether a type in some specified types""" + valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + + def raise_error_msg(): + """func for raising error message when check failed""" + type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types] + num_types = len(valid_types) + msg_prefix = f"For '{prim_name}', the" if prim_name else "The" + raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}" + f"{type_names if num_types > 1 else type_names[0]}, " + f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.") + + if isinstance(arg_type, type(mstype.tensor)): + arg_type = arg_type.element_type() + if arg_type not in valid_types: + raise_error_msg() + return arg_type + +class OutputTo(nn.Cell): + """Cast cell output back to float16 or float32""" + + def __init__(self, op, to_type=mstype.float16): + super(OutputTo, self).__init__(auto_prefix=False) + self._op = op + check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None) + self.to_type = to_type + + def construct(self, x): + return F.cast(self._op(x), self.to_type) + + +def auto_mixed_precision(network): + """Do keep batchnorm fp32.""" + cells = network.name_cells() + change = False + network.to_float(mstype.float16) + for name in cells: + subcell = cells[name] + if subcell == network: + continue + elif name == 'fc': + network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32)) + change = True + elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)): + network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16)) + else: + auto_mixed_precision(subcell) + if isinstance(network, nn.SequentialCell) and change: + network.cell_list = list(network.cells()) diff --git a/model_zoo/official/cv/cspdarknet53/src/utils/custom_op.py b/model_zoo/official/cv/cspdarknet53/src/utils/custom_op.py new file mode 100644 index 00000000000..ddc0599c22e --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/utils/custom_op.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""network operations.""" +import mindspore.nn as nn +from mindspore.ops import operations as P + +class GlobalAvgPooling(nn.Cell): + '''GlobalAvgPooling''' + def __init__(self): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(False) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x diff --git a/model_zoo/official/cv/cspdarknet53/src/utils/logging.py b/model_zoo/official/cv/cspdarknet53/src/utils/logging.py new file mode 100644 index 00000000000..af701bb0067 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/utils/logging.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Custom logger.""" +import logging +import os +import sys +from datetime import datetime + +class LOGGER(logging.Logger): + '''Logger.''' + def __init__(self, logger_name, local_rank=0): + super(LOGGER, self).__init__(logger_name) + self.local_rank = local_rank + if local_rank % 8 == 0: + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + console.setFormatter(formatter) + self.addHandler(console) + + def setup_logging_file(self, log_dir, local_rank=0): + '''Setup logging file.''' + self.local_rank = local_rank + if self.local_rank == 0: + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log' + self.log_fn = os.path.join(log_dir, log_name) + fh = logging.FileHandler(self.log_fn) + fh.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + fh.setFormatter(formatter) + self.addHandler(fh) + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and self.local_rank == 0: + self._log(logging.INFO, msg, args, **kwargs) + + def save_args(self, args): + self.info('Args:') + args_dict = vars(args) + for key in args_dict.keys(): + self.info('--> %s: %s', key, args_dict[key]) + self.info('') + + def important_info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and self.local_rank == 0: + line_width = 2 + important_msg = '\n' + important_msg += ('*'*70 + '\n')*line_width + important_msg += ('*'*line_width + '\n')*2 + important_msg += '*'*line_width + ' '*8 + msg + '\n' + important_msg += ('*'*line_width + '\n')*2 + important_msg += ('*'*70 + '\n')*line_width + self.info(important_msg, *args, **kwargs) + + +def get_logger(path, rank): + logger = LOGGER("cspdarknet53", rank) + logger.setup_logging_file(path, rank) + return logger diff --git a/model_zoo/official/cv/cspdarknet53/src/utils/optimizers_init.py b/model_zoo/official/cv/cspdarknet53/src/utils/optimizers_init.py new file mode 100644 index 00000000000..997475859bc --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/utils/optimizers_init.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""optimizer parameters.""" + +def get_param_groups(net): + """get param groups""" + decay_params = [] + no_decay_params = [] + for x in net.trainable_params(): + param_name = x.name + if param_name.endswith('.bias'): + no_decay_params.append(x) + elif param_name.endswith('.gamma'): + no_decay_params.append(x) + elif param_name.endswith('.beta'): + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] diff --git a/model_zoo/official/cv/cspdarknet53/src/utils/sampler.py b/model_zoo/official/cv/cspdarknet53/src/utils/sampler.py new file mode 100644 index 00000000000..1731932120a --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/utils/sampler.py @@ -0,0 +1,52 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +choose samples from the dataset +""" +import math +import numpy as np + +class DistributedSampler(): + """ + sampling the dataset. + + Args: + Returns: + num_samples, number of samples. + """ + def __init__(self, dataset, rank, group_size, shuffle=True, seed=0): + self.dataset = dataset + self.rank = rank + self.group_size = group_size + self.dataset_len = len(self.dataset) + self.num_samples = int(math.ceil(self.dataset_len * 1.0 / self.group_size)) + self.total_size = self.num_samples * self.group_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self): + if self.shuffle: + self.seed = (self.seed + 1) & 0xffffffff + np.random.seed(self.seed) + indices = np.random.permutation(self.dataset_len).tolist() + else: + indices = list(range(len(self.dataset_len))) + + indices += indices[:(self.total_size - len(indices))] + indices = indices[self.rank::self.group_size] + return iter(indices) + + def __len__(self): + return self.num_samples diff --git a/model_zoo/official/cv/cspdarknet53/src/utils/var_init.py b/model_zoo/official/cv/cspdarknet53/src/utils/var_init.py new file mode 100644 index 00000000000..fd89909f53a --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/src/utils/var_init.py @@ -0,0 +1,155 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Initialize.""" +import os +import math +from functools import reduce +import numpy as np +import mindspore.nn as nn +from mindspore.common import initializer as init +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore as ms +from mindspore import Tensor + +def _calculate_gain(nonlinearity, param=None): + """calculate_gain""" + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + res = 0 + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + res = 1 + elif nonlinearity == 'tanh': + res = 5.0 / 3 + elif nonlinearity == 'relu': + res = math.sqrt(2.0) + elif nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + res = math.sqrt(2.0 / (1 + negative_slope ** 2)) + else: + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + return res + +def _assignment(arr, num): + """Assign the value of `num` to `arr`.""" + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr.reshape(()) + else: + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num + return arr + +def _calculate_in_and_out(tensor): + """_calculate_fan_in_and_fan_out""" + dimensions = len(tensor.shape) + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") + + fan_in = tensor.shape[1] + fan_out = tensor.shape[0] + + if dimensions > 2: + counter = reduce(lambda x, y: x * y, tensor.shape[2:]) + fan_in *= counter + fan_out *= counter + + return fan_in, fan_out + +def _select_fan(tensor, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + fan_in, fan_out = _calculate_in_and_out(tensor) + return fan_in if mode == 'fan_in' else fan_out + + +class KaimingInit(init.Initializer): + """Base class. Initialize the array with HeKaiming init algorithm.""" + def __init__(self, a=0., mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingInit, self).__init__() + self.mode = mode + self.gain = _calculate_gain(nonlinearity, a) + + def _initialize(self, arr): + raise NotImplementedError("Init algorithm not-implemented.") + +class KaimingUniform(KaimingInit): + """KaimingUniform init algorithm.""" + + def _initialize(self, arr): + fan = _select_fan(arr, self.mode) + bound = math.sqrt(3.0) * self.gain / math.sqrt(fan) + data = np.random.uniform(-bound, bound, arr.shape) + + _assignment(arr, data) + + +class KaimingNormal(KaimingInit): + """KaimingNormal init algorithm.""" + + def _initialize(self, arr): + fan = _select_fan(arr, self.mode) + std = self.gain / math.sqrt(fan) + data = np.random.normal(0, std, arr.shape) + + _assignment(arr, data) + + +def default_recurisive_init(custom_cell): + ms.common.set_seed(0) + for _, cell in custom_cell.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5.0)), cell.weight.data.shape, + cell.weight.data.dtype).to_tensor()) + if cell.bias is not None: + fan_in, _ = _calculate_in_and_out(cell.weight.data.asnumpy()) + bound = 1 / math.sqrt(fan_in) + cell.bias.set_data(Tensor(np.random.uniform(-bound, bound, cell.bias.data.shape), cell.bias.data.dtype)) + elif isinstance(cell, nn.Dense): + cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), cell.weight.data.shape, + cell.weight.data.dtype).to_tensor()) + if cell.bias is not None: + fan_in, _ = _calculate_in_and_out(cell.weight.data.asnumpy()) + bound = 1 / math.sqrt(fan_in) + cell.bias.set_data(Tensor(np.random.uniform(-bound, bound, cell.bias.data.shape), cell.bias.data.dtype)) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + pass + + +def load_pretrain_model(ckpt_file, network, args): + """load pretrain model.""" + if os.path.isfile(ckpt_file): + param_dict = load_checkpoint(ckpt_file) + param_dict_new = {} + for k, v in param_dict.items(): + if k.startswith('moments.'): + continue + elif k.startswith('network.'): + param_dict_new[k[8:]] = v + else: + param_dict_new[k] = v + load_param_into_net(network, param_dict_new) + args.logger.info("Load pretrained {:s} success".format(ckpt_file)) + else: + args.logger.info("Do not load pretrained.") diff --git a/model_zoo/official/cv/cspdarknet53/train.py b/model_zoo/official/cv/cspdarknet53/train.py new file mode 100644 index 00000000000..1215b10a147 --- /dev/null +++ b/model_zoo/official/cv/cspdarknet53/train.py @@ -0,0 +1,216 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train scripts.""" +import os +import time +import datetime + +from mindspore import Tensor +from mindspore import context +from mindspore.context import ParallelMode +from mindspore.nn.optim import Momentum +from mindspore.communication.management import init +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback +from mindspore.train.model import Model +from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager +from mindspore.common import set_seed + +from src.dataset import create_dataset +from src.loss import CrossEntropy +from src.lr_generator import get_lr +from src.utils.logging import get_logger +from src.utils.optimizers_init import get_param_groups +from src.utils.var_init import load_pretrain_model +from src.image_classification import CSPDarknet53 + +from model_utils.config import config as default_config +from model_utils.moxing_adapter import moxing_wrapper +from model_utils.device_adapter import get_device_id, get_rank_id, get_device_num + + +set_seed(1) + + +class ProgressMonitor(Callback): + """monitor loss and cost time.""" + def __init__(self, args): + super(ProgressMonitor, self).__init__() + self.me_epoch_start_time = 0 + self.me_epoch_start_step_num = 0 + self.args = args + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + + cur_step_num = cb_params.cur_step_num - 1 + _epoch = cb_params.cur_epoch_num + time_cost = time.time() - self.me_epoch_start_time + fps_mean = self.args.per_batch_size * (cur_step_num - self.me_epoch_start_step_num) * \ + self.args.group_size / time_cost + per_step_time = 1000 * time_cost / (cur_step_num - self.me_epoch_start_step_num) + self.args.logger.info('epoch[{}], iter[{}], loss: {}, mean_fps: {:.2f}' + ' imgs/sec, per_step_time: {:.2f} ms'.format(_epoch, + cur_step_num % self.args.steps_per_epoch, + cb_params.net_outputs, + fps_mean, + per_step_time)) + + self.me_epoch_start_step_num = cur_step_num + self.me_epoch_start_time = time.time() + + +def set_default_args(args): + args.lr_epochs = list(map(int, args.lr_epochs.split(','))) + args.image_size = list(map(int, args.image_size.split(','))) + + args.rank = get_rank_id() + args.group_size = get_device_num() + + args.group_size = get_device_num() + if args.is_dynamic_loss_scale == 1: + args.loss_scale = 1 + + args.rank_save_ckpt_flag = 0 + if args.is_save_on_master: + if args.rank == 0: + args.rank_save_ckpt_flag = 1 + else: + args.rank_save_ckpt_flag = 1 + + args.outputs_dir = os.path.join(args.ckpt_path, + datetime.datetime.now().strftime("%Y-%m-%d_time_%H_%M_%S")) + args.logger = get_logger(args.outputs_dir, args.rank) + return args + + +def modelarts_pre_process(): + '''modelarts pre process function.''' + def unzip(zip_file, save_dir): + import zipfile + s_time = time.time() + if not os.path.exists(os.path.join(save_dir, default_config.modelarts_dataset_unzip_name)): + zip_isexist = zipfile.is_zipfile(zip_file) + if zip_isexist: + fz = zipfile.ZipFile(zip_file, 'r') + data_num = len(fz.namelist()) + print("Extract Start...") + print("unzip file num: {}".format(data_num)) + data_print = int(data_num / 100) if data_num > 100 else 1 + i = 0 + for file in fz.namelist(): + if i % data_print == 0: + print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True) + i += 1 + fz.extract(file, save_dir) + print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60), + int(int(time.time() - s_time) % 60))) + print("Extract Done.") + else: + print("This is not zip.") + else: + print("Zip has been extracted.") + + if default_config.need_modelarts_dataset_unzip: + zip_file_1 = os.path.join(default_config.data_path, default_config.modelarts_dataset_unzip_name + ".zip") + save_dir_1 = os.path.join(default_config.data_path) + + sync_lock = "/tmp/unzip_sync.lock" + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("Zip file path: ", zip_file_1) + print("Unzip file save dir: ", save_dir_1) + unzip(zip_file_1, save_dir_1) + print("===Finish extract data synchronization===") + try: + os.mknod(sync_lock) + except IOError: + pass + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) + + default_config.ckpt_path = os.path.join(default_config.output_path, default_config.ckpt_path) + + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_train(): + config = set_default_args(default_config) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=config.device_target, save_graphs=False, device_id=get_device_id()) + if config.is_distributed: + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=config.group_size, + gradients_mean=True) + init() + + # dataloader + de_dataset = create_dataset(config.data_dir, config.image_size, config.per_batch_size, + config.rank, config.group_size, num_parallel_workers=8) + de_dataset.map_model = 4 + config.steps_per_epoch = de_dataset.get_dataset_size() + + config.logger.save_args(config) + + # network + config.logger.important_info('start create network') + network = CSPDarknet53(num_classes=config.num_classes) + load_pretrain_model(config.pretrained, network, config) + + # lr + lr = get_lr(config) + + # optimizer + opt = Momentum(params=get_param_groups(network), + learning_rate=Tensor(lr), + momentum=config.momentum, + weight_decay=config.weight_decay, + loss_scale=config.loss_scale) + + # loss + if not config.label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes) + + if config.is_dynamic_loss_scale == 1: + loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) + else: + loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, + metrics={'acc'}, amp_level="O2") + + # checkpoint save + progress_cb = ProgressMonitor(config) + callbacks = [progress_cb,] + if config.rank_save_ckpt_flag: + ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_interval * config.steps_per_epoch, + keep_checkpoint_max=config.ckpt_save_max) + save_ckpt_path = os.path.join(config.outputs_dir, 'ckpt_' + str(config.rank) + '') + ckpt_cb = ModelCheckpoint(config=ckpt_config, + directory=save_ckpt_path, + prefix='{}'.format(config.rank)) + callbacks.append(ckpt_cb) + + model.train(config.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True) + + + +if __name__ == '__main__': + run_train()