forked from mindspore-Ecosystem/mindspore
commit
e5e9816cd3
|
@ -0,0 +1,224 @@
|
|||
# Contents
|
||||
|
||||
- [NTS-Net Description](#NTS-Net-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Knowledge Distillation Process](#knowledge-distillation-process)
|
||||
- [Prediction Process](#prediction-process)
|
||||
- [Evaluation with cityscape dataset](#evaluation-with-cityscape-dataset)
|
||||
- [Export MindIR](#export-mindir)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [NTS-Net Description](#contents)
|
||||
|
||||
NTS-Net for Navigator-Teacher-Scrutinizer Network, consists of a Navigator agent, a Teacher agent and a Scrutinizer agent. In consideration of intrinsic consistency between informativeness of the regions and their probability being ground-truth class, NTS-Net designs a novel training paradigm, which enables Navigator to detect most informative regions under the guidance from Teacher. After that, the Scrutinizer scrutinizes the proposed regions from Navigator and makes predictions
|
||||
[Paper](https://arxiv.org/abs/1809.00287): Z. Yang, T. Luo, D. Wang, Z. Hu, J. Gao, and L. Wang, Learning to navigate for fine-grained classification, in Proceedings of the European Conference on Computer Vision (ECCV), 2018.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
NTS-Net consists of a Navigator agent, a Teacher agent and a Scrutinizer agent. The Navigator navigates the model to focus on the most informative regions: for each region in the image, Navigator predicts how informative the region is, and the predictions are used to propose the most informative regions. The Teacher evaluates the regions proposed by Navigator and provides feedbacks: for each proposed region, the Teacher evaluates its probability belonging to ground-truth class; the confidence evaluations guide the Navigator to propose more informative regions with a novel ordering-consistent loss function. The Scrutinizer scrutinizes proposed regions from Navigator and makes fine-grained classifications: each proposed region is enlarged to the same size and the Scrutinizer extracts features therein; the features of regions and of the whole image are jointly processed to make fine-grained classifications.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
Dataset used: [Caltech-UCSD Birds-200-2011](<http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>)
|
||||
|
||||
Please download the datasets [CUB_200_2011.tgz] and unzip it, then put all training images into a directory named "train", put all testing images into a directory named "test".
|
||||
|
||||
The directory structure is as follows:
|
||||
|
||||
```path
|
||||
.
|
||||
└─cub_200_2011
|
||||
├─train
|
||||
└─test
|
||||
```
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware Ascend
|
||||
- Prepare hardware environment with Ascend processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```shell
|
||||
.
|
||||
└─ntsnet
|
||||
├─README.md # README
|
||||
├─scripts # shell script
|
||||
├─run_standalone_train.sh # training in standalone mode(1pcs)
|
||||
├─run_distribute_train.sh # training in parallel mode(8 pcs)
|
||||
└─run_eval.sh # evaluation
|
||||
├─src
|
||||
├─config.py # network configuration
|
||||
├─dataset.py # dataset utils
|
||||
├─lr_generator.py # leanring rate generator
|
||||
├─network.py # network define for ntsnet
|
||||
└─resnet.py # resnet.py
|
||||
├─mindspore_hub_conf.py # mindspore hub interface
|
||||
├─export.py # script to export MINDIR model
|
||||
├─eval.py # evaluation scripts
|
||||
└─train.py # training scripts
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
### [Training Script Parameters](#contents)
|
||||
|
||||
```shell
|
||||
# distributed training
|
||||
Usage: bash run_train.sh [RANK_TABLE_FILE] [DATA_URL] [TRAIN_URL]
|
||||
|
||||
# standalone training
|
||||
Usage: bash run_standalone_train.sh [DATA_URL] [TRAIN_URL]
|
||||
```
|
||||
|
||||
### [Parameters Configuration](#contents)
|
||||
|
||||
```txt
|
||||
"img_width": 448, # width of the input images
|
||||
"img_height": 448, # height of the input images
|
||||
|
||||
# anchor
|
||||
"size": [48, 96, 192], #anchor base size
|
||||
"scale": [1, 2 ** (1. / 3.), 2 ** (2. / 3.)], #anchor base scale
|
||||
"aspect_ratio": [0.667, 1, 1.5], #anchor base aspect_ratio
|
||||
"stride": [32, 64, 128], #anchor base stride
|
||||
|
||||
# resnet
|
||||
"resnet_block": [3, 4, 6, 3], # block number in each layer
|
||||
"resnet_in_channels": [64, 256, 512, 1024], # in channel size for each layer
|
||||
"resnet_out_channels": [256, 512, 1024, 2048], # out channel size for each layer
|
||||
|
||||
# LR
|
||||
"base_lr": 0.001, # base learning rate
|
||||
"base_step": 58633, # bsae step in lr generator
|
||||
"total_epoch": 200, # total epoch in lr generator
|
||||
"warmup_step": 4, # warmp up step in lr generator
|
||||
"sgd_momentum": 0.9, # momentum in optimizer
|
||||
|
||||
# train
|
||||
"batch_size": 8,
|
||||
"weight_decay": 1e-4,
|
||||
"epoch_size": 200, # total epoch size
|
||||
"save_checkpoint": True, # whether save checkpoint or not
|
||||
"save_checkpoint_epochs": 1, # save checkpoint interval
|
||||
"num_classes": 200
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
- Set options in `config.py`, including learning rate, output filename and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset.
|
||||
|
||||
### [Training](#content)
|
||||
|
||||
- Run `run_standalone_train.sh` for non-distributed training of NTS-Net model.
|
||||
|
||||
```bash
|
||||
# standalone training
|
||||
bash run_standalone_train.sh [DATA_URL] [TRAIN_URL]
|
||||
```
|
||||
|
||||
### [Distributed Training](#content)
|
||||
|
||||
- Run `run_distribute_train.sh` for distributed training of NTS-Net model.
|
||||
|
||||
```bash
|
||||
bash run_train.sh [RANK_TABLE_FILE] [DATA_URL] [TRAIN_URL]
|
||||
```
|
||||
|
||||
- Notes
|
||||
1. hccl.json which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
2. As for PRETRAINED_MODEL,it should be a trained ResNet50 checkpoint.
|
||||
|
||||
### [Training Result](#content)
|
||||
|
||||
Training result will be stored in train_url path. You can find checkpoint file together with result like the following in loss.log.
|
||||
|
||||
```bash
|
||||
# distribute training result(8p)
|
||||
epoch: 1 step: 750 ,loss: 30.88018
|
||||
epoch: 2 step: 750 ,loss: 26.73352
|
||||
epoch: 3 step: 750 ,loss: 22.76208
|
||||
epoch: 4 step: 750 ,loss: 20.52259
|
||||
epoch: 5 step: 750 ,loss: 19.34843
|
||||
epoch: 6 step: 750 ,loss: 17.74093
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### [Evaluation](#content)
|
||||
|
||||
- Run `run_eval.sh` for evaluation.
|
||||
|
||||
```bash
|
||||
# infer
|
||||
sh run_eval.sh [DATA_URL] [TRAIN_URL] [CKPT_FILENAME]
|
||||
```
|
||||
|
||||
### [Evaluation result](#content)
|
||||
|
||||
Inference result will be stored in the train_url path. Under this, you can find result like the following in eval.log.
|
||||
|
||||
```bash
|
||||
ckpt file name: ntsnet-112_750.ckpt
|
||||
accuracy: 0.876
|
||||
```
|
||||
|
||||
## Model Export
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format[EXPORT_FORMAT]
|
||||
```
|
||||
|
||||
`EXPORT_FORMAT` should be "MINDIR"
|
||||
|
||||
# Model Description
|
||||
|
||||
## Performance
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
|
||||
| uploaded Date | 16/04/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1.1 |
|
||||
| Dataset | cub200-2011 |
|
||||
| Training Parameters | epoch=200, batch_size = 8 |
|
||||
| Optimizer | SGD |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| Output | predict class |
|
||||
| Loss | 10.9852 |
|
||||
| Speed | 1pc: 130 ms/step; 8pcs: 138 ms/step |
|
||||
| Total time | 8pcs: 5.93 hours |
|
||||
| Parameters | 87.6 |
|
||||
| Checkpoint for Fine tuning | 333.07M(.ckpt file) |
|
||||
| Scripts | [ntsnet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ntsnet) |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We use random seed in train.py and eval.py for weight initialization.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ntsnet eval."""
|
||||
import argparse
|
||||
import ast
|
||||
import os
|
||||
from mindspore import context, set_seed, Tensor, load_checkpoint, load_param_into_net, ops
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src.config import config
|
||||
from src.dataset import create_dataset_test
|
||||
from src.network import NTS_NET
|
||||
parser = argparse.ArgumentParser(description='ntsnet eval running')
|
||||
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
|
||||
parser.add_argument('--data_url', default=None, help='Directory contains CUB_200_2011 dataset.')
|
||||
parser.add_argument('--train_url', default=None, help='Directory contains checkpoint file and eval.log')
|
||||
parser.add_argument('--ckpt_filename', default=None, help='checkpoint file name')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
args = parser.parse_args()
|
||||
run_modelart = args.run_modelart
|
||||
if not run_modelart:
|
||||
device_id = args.device_id
|
||||
else:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
||||
batch_size = config.batch_size
|
||||
|
||||
resnet50Path = ""
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
if run_modelart:
|
||||
import moxing as mox
|
||||
|
||||
local_input_url = '/cache/data' + str(device_id)
|
||||
local_output_url = '/cache/ckpt' + str(device_id)
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_input_url)
|
||||
mox.file.copy_parallel(src_url=os.path.join(args.train_url, args.ckpt_filename),
|
||||
dst_url=os.path.join(local_output_url, args.ckpt_filename))
|
||||
mox.file.copy_parallel(src_url=os.path.join(args.train_url, "eval.log"),
|
||||
dst_url=os.path.join(local_output_url, "eval.log"))
|
||||
else:
|
||||
local_input_url = args.data_url
|
||||
local_output_url = args.train_url
|
||||
|
||||
|
||||
def print2file(obj1, obj2):
|
||||
with open(os.path.join(local_output_url, 'eval.log'), 'a') as f:
|
||||
f.write(str(obj1))
|
||||
f.write(' ')
|
||||
f.write(str(obj2))
|
||||
f.write(' \r\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
set_seed(1)
|
||||
test_data_set = create_dataset_test(test_path=os.path.join(local_input_url, "CUB_200_2011/test"),
|
||||
batch_size=batch_size)
|
||||
test_data_loader = test_data_set.create_dict_iterator(output_numpy=True)
|
||||
|
||||
ntsnet = NTS_NET(topK=6, resnet50Path=resnet50Path)
|
||||
param_dict = load_checkpoint(os.path.join(local_output_url, args.ckpt_filename))
|
||||
load_param_into_net(ntsnet, param_dict)
|
||||
ntsnet.set_train(False)
|
||||
success_num = 0.0
|
||||
total_num = 0.0
|
||||
for _, data in enumerate(test_data_loader):
|
||||
image_data = Tensor(data['image'], mstype.float32)
|
||||
label = Tensor(data["label"], mstype.int32)
|
||||
_, scrutinizer_out, _, _ = ntsnet(image_data)
|
||||
result_label, _ = ops.ArgMaxWithValue(1)(scrutinizer_out)
|
||||
success_num = success_num + sum((result_label == label).asnumpy())
|
||||
total_num = total_num + float(image_data.shape[0])
|
||||
print2file("ckpt file name: ", args.ckpt_filename)
|
||||
print2file("accuracy: ", round(success_num / total_num, 3))
|
||||
if run_modelart:
|
||||
mox.file.copy_parallel(src_url=os.path.join(local_output_url, "eval.log"),
|
||||
dst_url=os.path.join(args.train_url, "eval.log"))
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, onnx, mindir models"""
|
||||
import argparse
|
||||
import ast
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.network import NTS_NET
|
||||
|
||||
parser = argparse.ArgumentParser(description='ntsnet export')
|
||||
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file name.")
|
||||
parser.add_argument('--data_url', default=None, help='Directory contains CUB_200_2011 dataset.')
|
||||
parser.add_argument('--train_url', default=None, help='Directory contains checkpoint file')
|
||||
parser.add_argument("--file_name", type=str, default="ntsnet", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, default="MINDIR", help="file format")
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
if args.run_modelart:
|
||||
import moxing as mox
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
local_output_url = '/cache/ckpt' + str(device_id)
|
||||
mox.file.copy_parallel(src_url=os.path.join(args.train_url, args.ckpt_file),
|
||||
dst_url=os.path.join(local_output_url, args.ckpt_file))
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = NTS_NET(topK=6)
|
||||
if args.run_modelart:
|
||||
param_dict = load_checkpoint(os.path.join(local_output_url, args.ckpt_file))
|
||||
else:
|
||||
param_dict = load_checkpoint(os.path.join(args.train_url, args.ckpt_file))
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
inputs = Tensor(np.random.rand(args.batch_size, 3, 448, 448), mstype.float32)
|
||||
export(net, inputs, file_name=args.file_name, file_format=args.file_format)
|
||||
if args.run_modelart:
|
||||
file_name = args.file_name + "." + args.file_format.lower()
|
||||
mox.file.copy_parallel(src_url=file_name,
|
||||
dst_url=os.path.join(args.train_url, file_name))
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.network import NTS_NET
|
||||
|
||||
|
||||
def create_network(name):
|
||||
if name == "ntsnet":
|
||||
return NTS_NET(topK=6)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,87 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_train.sh [RANK_TABLE_FILE] [DATA_URL] [TRAIN_URL]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
PATH3=$(get_real_path $3)
|
||||
|
||||
echo $PATH1
|
||||
echo $PATH2
|
||||
echo $PATH3
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: DATA_URL=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH3 ]
|
||||
then
|
||||
echo "error: TRAIN_URL=$PATH3 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export HCCL_CONNECT_TIMEOUT=600
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
echo 3 > /proc/sys/vm/drop_caches
|
||||
|
||||
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
|
||||
avg=`expr $cpus \/ $DEVICE_NUM`
|
||||
gap=`expr $avg \- 1`
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
start=`expr $i \* $avg`
|
||||
end=`expr $start \+ $gap`
|
||||
cmdopt=$start"-"$end
|
||||
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
taskset -c $cmdopt python train.py --device_id=$i --run_distribute=True --device_num=$DEVICE_NUM \
|
||||
--data_url=$PATH2 --train_url=$PATH3 &> log &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,72 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [DATA_URL] [TRAIN_URL] [CKPT_FILENAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
PATH3=$(get_real_path $3)
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATA_URL=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: TRAIN_URL=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH3 ]
|
||||
then
|
||||
echo "error: CKPT_FILENAME=$PATH3 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --device_id=$DEVICE_ID --data_url=$PATH1 --train_url=$PATH2 \
|
||||
--ckpt_filename=$PATH3 &> log &
|
||||
cd ..
|
|
@ -0,0 +1,65 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: bash run_standalone_train.sh [DATA_URL] [TRAIN_URL]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATA_URL=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: TRAIN_URL=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --data_url=$PATH1 --train_url=$PATH2 &> log &
|
||||
cd ..
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 112,
|
||||
"keep_checkpoint_max": 10000,
|
||||
"learning_rate": 0.001,
|
||||
"m_for_scrutinizer": 4,
|
||||
"topK": 6,
|
||||
"input_size": (448, 448),
|
||||
"weight_decay": 1e-4,
|
||||
"momentum": 0.9,
|
||||
"num_epochs": 112,
|
||||
"num_classes": 200,
|
||||
"num_train_images": 5994,
|
||||
"num_test_images": 5794,
|
||||
"batch_size": 8,
|
||||
"prefix": "ntsnet",
|
||||
"lossLogName": "loss.log"
|
||||
})
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ntsnet dataset"""
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
from mindspore.dataset.vision import Inter
|
||||
|
||||
|
||||
def create_dataset_train(train_path, batch_size):
|
||||
"""create train dataset"""
|
||||
train_data_set = ds.ImageFolderDataset(train_path, shuffle=True)
|
||||
# define map operations
|
||||
transform_img = [
|
||||
vision.Decode(),
|
||||
vision.Resize([448, 448], Inter.LINEAR),
|
||||
vision.RandomHorizontalFlip(),
|
||||
vision.HWC2CHW()
|
||||
]
|
||||
train_data_set = train_data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img,
|
||||
output_columns="image")
|
||||
train_data_set = train_data_set.map(input_columns="image", num_parallel_workers=8,
|
||||
operations=lambda x: (x / 255).astype("float32"))
|
||||
train_data_set = train_data_set.batch(batch_size)
|
||||
return train_data_set
|
||||
|
||||
|
||||
def create_dataset_test(test_path, batch_size):
|
||||
"""create test dataset"""
|
||||
test_data_set = ds.ImageFolderDataset(test_path, shuffle=False)
|
||||
# define map operations
|
||||
transform_img = [
|
||||
vision.Decode(),
|
||||
vision.Resize([448, 448], Inter.LINEAR),
|
||||
vision.HWC2CHW()
|
||||
]
|
||||
test_data_set = test_data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img,
|
||||
output_columns="image")
|
||||
test_data_set = test_data_set.map(input_columns="image", num_parallel_workers=8,
|
||||
operations=lambda x: (x / 255).astype("float32"))
|
||||
test_data_set = test_data_set.batch(batch_size)
|
||||
return test_data_set
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr generator for ntsnet"""
|
||||
import numpy as np
|
||||
|
||||
def get_lr(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
|
||||
"""
|
||||
generate learning rate
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
elif i < 100 * steps_per_epoch:
|
||||
lr = lr_max
|
||||
else:
|
||||
lr = lr_max * 0.1
|
||||
lr_each_step.append(lr)
|
||||
current_step = global_step
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
return learning_rate
|
|
@ -0,0 +1,531 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ntsnet network wrapper."""
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
from mindspore import ops, load_checkpoint, load_param_into_net, Tensor, nn
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.callback._callback import set_cur_net
|
||||
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
|
||||
_chg_ckpt_file_name_if_same_exist
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import save_checkpoint, _save_graph
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
|
||||
from src.resnet import resnet50
|
||||
from src.config import config
|
||||
|
||||
m_for_scrutinizer = config.m_for_scrutinizer
|
||||
K = config.topK
|
||||
input_size = config.input_size
|
||||
num_classes = config.num_classes
|
||||
lossLogName = config.lossLogName
|
||||
|
||||
|
||||
def _fc(in_channel, out_channel):
|
||||
'''Weight init for dense cell'''
|
||||
stdv = 1 / math.sqrt(in_channel)
|
||||
weight = Tensor(np.random.uniform(-stdv, stdv, (out_channel, in_channel)).astype(np.float32))
|
||||
bias = Tensor(np.random.uniform(-stdv, stdv, (out_channel)).astype(np.float32))
|
||||
return nn.Dense(in_channel, out_channel, has_bias=True,
|
||||
weight_init=weight, bias_init=bias).to_float(mstype.float32)
|
||||
|
||||
|
||||
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
|
||||
"""Conv2D wrapper."""
|
||||
shape = (out_channels, in_channels, kernel_size, kernel_size)
|
||||
stdv = 1 / math.sqrt(in_channels * kernel_size * kernel_size)
|
||||
weights = Tensor(np.random.uniform(-stdv, stdv, shape).astype(np.float32))
|
||||
shape_bias = (out_channels,)
|
||||
biass = Tensor(np.random.uniform(-stdv, stdv, shape_bias).astype(np.float32))
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
pad_mode=pad_mode, weight_init=weights, has_bias=True, bias_init=biass)
|
||||
|
||||
|
||||
_default_anchors_setting = (
|
||||
dict(layer='p3', stride=32, size=48, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
|
||||
dict(layer='p4', stride=64, size=96, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
|
||||
dict(layer='p5', stride=128, size=192, scale=[1, 2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
|
||||
)
|
||||
|
||||
|
||||
def generate_default_anchor_maps(anchors_setting=None, input_shape=input_size):
|
||||
"""
|
||||
generate default anchor
|
||||
|
||||
:param anchors_setting: all information of anchors
|
||||
:param input_shape: shape of input images, e.g. (h, w)
|
||||
:return: center_anchors: # anchors * 4 (oy, ox, h, w)
|
||||
edge_anchors: # anchors * 4 (y0, x0, y1, x1)
|
||||
anchor_area: # anchors * 1 (area)
|
||||
"""
|
||||
if anchors_setting is None:
|
||||
anchors_setting = _default_anchors_setting
|
||||
|
||||
center_anchors = np.zeros((0, 4), dtype=np.float32)
|
||||
edge_anchors = np.zeros((0, 4), dtype=np.float32)
|
||||
anchor_areas = np.zeros((0,), dtype=np.float32)
|
||||
input_shape = np.array(input_shape, dtype=int)
|
||||
|
||||
for anchor_info in anchors_setting:
|
||||
stride = anchor_info['stride']
|
||||
size = anchor_info['size']
|
||||
scales = anchor_info['scale']
|
||||
aspect_ratios = anchor_info['aspect_ratio']
|
||||
|
||||
output_map_shape = np.ceil(input_shape.astype(np.float32) / stride)
|
||||
output_map_shape = output_map_shape.astype(np.int)
|
||||
output_shape = tuple(output_map_shape) + (4,)
|
||||
ostart = stride / 2.
|
||||
oy = np.arange(ostart, ostart + stride * output_shape[0], stride)
|
||||
oy = oy.reshape(output_shape[0], 1)
|
||||
ox = np.arange(ostart, ostart + stride * output_shape[1], stride)
|
||||
ox = ox.reshape(1, output_shape[1])
|
||||
center_anchor_map_template = np.zeros(output_shape, dtype=np.float32)
|
||||
center_anchor_map_template[:, :, 0] = oy
|
||||
center_anchor_map_template[:, :, 1] = ox
|
||||
for scale in scales:
|
||||
for aspect_ratio in aspect_ratios:
|
||||
center_anchor_map = center_anchor_map_template.copy()
|
||||
center_anchor_map[:, :, 2] = size * scale / float(aspect_ratio) ** 0.5
|
||||
center_anchor_map[:, :, 3] = size * scale * float(aspect_ratio) ** 0.5
|
||||
edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - center_anchor_map[..., 2:4] / 2.,
|
||||
center_anchor_map[..., :2] + center_anchor_map[..., 2:4] / 2.),
|
||||
axis=-1)
|
||||
anchor_area_map = center_anchor_map[..., 2] * center_anchor_map[..., 3]
|
||||
center_anchors = np.concatenate((center_anchors, center_anchor_map.reshape(-1, 4)))
|
||||
edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4)))
|
||||
anchor_areas = np.concatenate((anchor_areas, anchor_area_map.reshape(-1)))
|
||||
return center_anchors, edge_anchors, anchor_areas
|
||||
|
||||
|
||||
class Navigator(nn.Cell):
|
||||
"""Navigator"""
|
||||
|
||||
def __init__(self):
|
||||
"""Navigator init"""
|
||||
super(Navigator, self).__init__()
|
||||
self.down1 = _conv(2048, 128, 3, 1, padding=1, pad_mode='pad')
|
||||
self.down2 = _conv(128, 128, 3, 2, padding=1, pad_mode='pad')
|
||||
self.down3 = _conv(128, 128, 3, 2, padding=1, pad_mode='pad')
|
||||
self.ReLU = nn.ReLU()
|
||||
self.tidy1 = _conv(128, 6, 1, 1, padding=0, pad_mode='same')
|
||||
self.tidy2 = _conv(128, 6, 1, 1, padding=0, pad_mode='same')
|
||||
self.tidy3 = _conv(128, 9, 1, 1, padding=0, pad_mode='same')
|
||||
self.opConcat = ops.Concat(axis=1)
|
||||
self.opReshape = ops.Reshape()
|
||||
|
||||
def construct(self, x):
|
||||
"""Navigator construct"""
|
||||
batch_size = x.shape[0]
|
||||
d1 = self.ReLU(self.down1(x))
|
||||
d2 = self.ReLU(self.down2(d1))
|
||||
d3 = self.ReLU(self.down3(d2))
|
||||
t1 = self.tidy1(d1)
|
||||
t2 = self.tidy2(d2)
|
||||
t3 = self.tidy3(d3)
|
||||
t1 = self.opReshape(t1, (batch_size, -1, 1))
|
||||
t2 = self.opReshape(t2, (batch_size, -1, 1))
|
||||
t3 = self.opReshape(t3, (batch_size, -1, 1))
|
||||
return self.opConcat((t1, t2, t3))
|
||||
|
||||
|
||||
class NTS_NET(nn.Cell):
|
||||
"""Ntsnet"""
|
||||
|
||||
def __init__(self, topK=6, resnet50Path=""):
|
||||
"""Ntsnet init"""
|
||||
super(NTS_NET, self).__init__()
|
||||
feature_extractor = resnet50(1001)
|
||||
if resnet50Path != "":
|
||||
param_dict = load_checkpoint(resnet50Path)
|
||||
load_param_into_net(feature_extractor, param_dict)
|
||||
self.feature_extractor = feature_extractor # Backbone
|
||||
self.feature_extractor.end_point = _fc(512 * 4, num_classes)
|
||||
self.navigator = Navigator() # Navigator
|
||||
self.topK = topK
|
||||
self.num_classes = num_classes
|
||||
self.scrutinizer = _fc(2048 * (m_for_scrutinizer + 1), num_classes) # Scrutinizer
|
||||
self.teacher = _fc(512 * 4, num_classes) # Teacher
|
||||
_, edge_anchors, _ = generate_default_anchor_maps()
|
||||
self.pad_side = 224
|
||||
self.Pad_ops = ops.Pad(((0, 0), (0, 0), (self.pad_side, self.pad_side), (self.pad_side, self.pad_side)))
|
||||
self.np_edge_anchors = edge_anchors + 224
|
||||
self.edge_anchors = Tensor(self.np_edge_anchors, mstype.float32)
|
||||
self.opzeros = ops.Zeros()
|
||||
self.opones = ops.Ones()
|
||||
self.concat_op = ops.Concat(axis=1)
|
||||
self.nms = P.NMSWithMask(0.25)
|
||||
self.topK_op = ops.TopK(sorted=True)
|
||||
self.opReshape = ops.Reshape()
|
||||
self.opResizeLinear = ops.ResizeBilinear((224, 224))
|
||||
self.transpose = ops.Transpose()
|
||||
self.opsCropResize = ops.CropAndResize(method="bilinear_v2")
|
||||
self.min_float_num = -65536.0
|
||||
self.selected_mask_shape = (1614,)
|
||||
self.unchosen_score = Tensor(self.min_float_num * np.ones(self.selected_mask_shape, np.float32),
|
||||
mstype.float32)
|
||||
self.gatherND = ops.GatherNd()
|
||||
self.gatherD = ops.GatherD()
|
||||
self.squeezeop = P.Squeeze()
|
||||
self.select = P.Select()
|
||||
self.perm = (1, 2, 0)
|
||||
self.box_index = self.opzeros(((K,)), mstype.int32)
|
||||
self.crop_size = (224, 224)
|
||||
self.perm2 = (0, 3, 1, 2)
|
||||
self.m_for_scrutinizer = m_for_scrutinizer
|
||||
self.sortop = ops.Sort(descending=True)
|
||||
self.stackop = ops.Stack()
|
||||
|
||||
def construct(self, x):
|
||||
"""Ntsnet construct"""
|
||||
resnet_out, rpn_feature, feature = self.feature_extractor(x)
|
||||
x_pad = self.Pad_ops(x)
|
||||
batch_size = x.shape[0]
|
||||
rpn_feature = F.stop_gradient(rpn_feature)
|
||||
rpn_score = self.navigator(rpn_feature)
|
||||
edge_anchors = self.edge_anchors
|
||||
top_k_info = []
|
||||
current_img_for_teachers = []
|
||||
for i in range(batch_size):
|
||||
# using navigator output as scores to nms anchors
|
||||
rpn_score_current_img = self.opReshape(rpn_score[i:i + 1:1, ::], (-1, 1))
|
||||
bbox_score = self.squeezeop(rpn_score_current_img)
|
||||
bbox_score_sorted, bbox_score_sorted_indices = self.sortop(bbox_score)
|
||||
bbox_score_sorted_concat = self.opReshape(bbox_score_sorted, (-1, 1))
|
||||
edge_anchors_sorted_concat = self.gatherND(edge_anchors,
|
||||
self.opReshape(bbox_score_sorted_indices, (1614, 1)))
|
||||
bbox = self.concat_op((edge_anchors_sorted_concat, bbox_score_sorted_concat))
|
||||
_, _, selected_mask = self.nms(bbox)
|
||||
selected_mask = F.stop_gradient(selected_mask)
|
||||
bbox_score = self.squeezeop(bbox_score_sorted_concat)
|
||||
scores_using = self.select(selected_mask, bbox_score, self.unchosen_score)
|
||||
# select the topk anchors and scores after nms
|
||||
_, topK_indices = self.topK_op(scores_using, self.topK)
|
||||
topK_indices = self.opReshape(topK_indices, (K, 1))
|
||||
bbox_topk = self.gatherND(bbox, topK_indices)
|
||||
top_k_info.append(self.opReshape(bbox_topk[::, 4:5:1], (-1,)))
|
||||
# crop from x_pad and resize to a fixed size using bilinear
|
||||
temp_pad = self.opReshape(x_pad[i:i + 1:1, ::, ::, ::], (3, 896, 896))
|
||||
temp_pad = self.transpose(temp_pad, self.perm)
|
||||
tensor_image = self.opReshape(temp_pad, (1,) + temp_pad.shape)
|
||||
tensor_box = self.gatherND(edge_anchors_sorted_concat, topK_indices)
|
||||
tensor_box = tensor_box / 895
|
||||
current_img_for_teacher = self.opsCropResize(tensor_image, tensor_box, self.box_index, self.crop_size)
|
||||
# the image cropped will be used to extractor feature and calculate loss
|
||||
current_img_for_teacher = self.opReshape(current_img_for_teacher, (-1, 224, 224, 3))
|
||||
current_img_for_teacher = self.transpose(current_img_for_teacher, self.perm2)
|
||||
current_img_for_teacher = self.opReshape(current_img_for_teacher, (-1, 3, 224, 224))
|
||||
current_img_for_teachers.append(current_img_for_teacher)
|
||||
feature = self.opReshape(feature, (batch_size, 1, -1))
|
||||
top_k_info = self.stackop(top_k_info)
|
||||
top_k_info = self.opReshape(top_k_info, (batch_size, self.topK))
|
||||
current_img_for_teachers = self.stackop(current_img_for_teachers)
|
||||
current_img_for_teachers = self.opReshape(current_img_for_teachers, (batch_size * self.topK, 3, 224, 224))
|
||||
current_img_for_teachers = F.stop_gradient(current_img_for_teachers)
|
||||
# extracor features of topk cropped images
|
||||
_, _, pre_teacher_features = self.feature_extractor(current_img_for_teachers)
|
||||
pre_teacher_features = self.opReshape(pre_teacher_features, (batch_size, self.topK, 2048))
|
||||
pre_scrutinizer_features = pre_teacher_features[::, 0:self.m_for_scrutinizer:1, ::]
|
||||
pre_scrutinizer_features = self.opReshape(pre_scrutinizer_features, (batch_size, self.m_for_scrutinizer, 2048))
|
||||
pre_scrutinizer_features = self.opReshape(self.concat_op((pre_scrutinizer_features, feature)), (batch_size, -1))
|
||||
# using topk cropped images, feed in scrutinzer and teacher, calculate loss
|
||||
scrutinizer_out = self.scrutinizer(pre_scrutinizer_features)
|
||||
teacher_out = self.teacher(pre_teacher_features)
|
||||
return resnet_out, scrutinizer_out, teacher_out, top_k_info
|
||||
# (batch_size, 200),(batch_size, 200),(batch_size,6, 200),(batch_size,6)
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""WithLossCell wrapper for ntsnet"""
|
||||
|
||||
def __init__(self, backbone, loss_fn):
|
||||
"""WithLossCell init"""
|
||||
super(WithLossCell, self).__init__(auto_prefix=True)
|
||||
self._backbone = backbone
|
||||
self._loss_fn = loss_fn
|
||||
self.oneTensor = Tensor(1.0, mstype.float32)
|
||||
self.zeroTensor = Tensor(0.0, mstype.float32)
|
||||
self.opReshape = ops.Reshape()
|
||||
self.opOnehot = ops.OneHot()
|
||||
self.oplogsoftmax = ops.LogSoftmax()
|
||||
self.opZeros = ops.Zeros()
|
||||
self.opOnes = ops.Ones()
|
||||
self.opRelu = ops.ReLU()
|
||||
self.opGatherD = ops.GatherD()
|
||||
self.squeezeop = P.Squeeze()
|
||||
self.reducesumop = ops.ReduceSum()
|
||||
self.oprepeat = ops.repeat_elements
|
||||
self.cast = ops.Cast()
|
||||
|
||||
def construct(self, image_data, label):
|
||||
"""WithLossCell construct"""
|
||||
batch_size = image_data.shape[0]
|
||||
origin_label = label
|
||||
labelx = self.opReshape(label, (-1, 1))
|
||||
origin_label_repeatk_2D = self.oprepeat(labelx, rep=K, axis=1)
|
||||
origin_label_repeatk = self.opReshape(origin_label_repeatk_2D, (-1,))
|
||||
origin_label_repeatk_unsqueeze = self.opReshape(origin_label_repeatk_2D, (-1, 1))
|
||||
resnet_out, scrutinizer_out, teacher_out, top_k_info = self._backbone(image_data)
|
||||
teacher_out = self.opReshape(teacher_out, (batch_size * K, -1))
|
||||
log_softmax_teacher_out = -1 * self.oplogsoftmax(teacher_out)
|
||||
log_softmax_teacher_out_result = self.opGatherD(log_softmax_teacher_out, 1, origin_label_repeatk_unsqueeze)
|
||||
log_softmax_teacher_out_result = self.opReshape(log_softmax_teacher_out_result, (batch_size, K))
|
||||
oneHotLabel = self.opOnehot(origin_label, num_classes, self.oneTensor, self.zeroTensor)
|
||||
# using resnet_out to calculate resnet_real_out_loss
|
||||
resnet_real_out_loss = self._loss_fn(resnet_out, oneHotLabel)
|
||||
# using scrutinizer_out to calculate scrutinizer_out_loss
|
||||
scrutinizer_out_loss = self._loss_fn(scrutinizer_out, oneHotLabel)
|
||||
# using teacher_out and top_k_info to calculate ranking loss
|
||||
loss = self.opZeros((), mstype.float32)
|
||||
num = top_k_info.shape[0]
|
||||
for i in range(K):
|
||||
log_softmax_teacher_out_inlabel_unsqueeze = self.opReshape(log_softmax_teacher_out_result[::, i:i + 1:1],
|
||||
(-1, 1))
|
||||
compareX = log_softmax_teacher_out_result > log_softmax_teacher_out_inlabel_unsqueeze
|
||||
pivot = self.opReshape(top_k_info[::, i:i + 1:1], (-1, 1))
|
||||
information = 1 - pivot + top_k_info
|
||||
loss_p = information * compareX
|
||||
loss_p_temp = self.opRelu(loss_p)
|
||||
loss_p = self.reducesumop(loss_p_temp)
|
||||
loss += loss_p
|
||||
rank_loss = loss / num
|
||||
oneHotLabel2 = self.opOnehot(origin_label_repeatk, num_classes, self.oneTensor, self.zeroTensor)
|
||||
# using teacher_out to calculate teacher_loss
|
||||
teacher_loss = self._loss_fn(teacher_out, oneHotLabel2)
|
||||
total_loss = resnet_real_out_loss + rank_loss + scrutinizer_out_loss + teacher_loss
|
||||
return total_loss
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
"""WithLossCell backbone"""
|
||||
return self._backbone
|
||||
|
||||
|
||||
class ModelCheckpoint(Callback):
|
||||
"""
|
||||
The checkpoint callback class.
|
||||
It is called to combine with train process and save the model and network parameters after training.
|
||||
Note:
|
||||
In the distributed training scenario, please specify different directories for each training process
|
||||
to save the checkpoint file. Otherwise, the training may fail.
|
||||
Args:
|
||||
prefix (str): The prefix name of checkpoint files. Default: "CKP".
|
||||
directory (str): The path of the folder which will be saved in the checkpoint file. Default: None.
|
||||
ckconfig (CheckpointConfig): Checkpoint strategy configuration. Default: None.
|
||||
Raises:
|
||||
ValueError: If the prefix is invalid.
|
||||
TypeError: If the config is not CheckpointConfig type.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix='CKP', directory=None, ckconfig=None,
|
||||
device_num=1, device_id=0, args=None, run_modelart=False):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
self._latest_ckpt_file_name = ""
|
||||
self._init_time = time.time()
|
||||
self._last_time = time.time()
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = 0
|
||||
self.run_modelart = run_modelart
|
||||
if _check_file_name_prefix(prefix):
|
||||
self._prefix = prefix
|
||||
else:
|
||||
raise ValueError("Prefix {} for checkpoint file name invalid, "
|
||||
"please check and correct it and then continue.".format(prefix))
|
||||
if directory is not None:
|
||||
self._directory = _make_directory(directory)
|
||||
else:
|
||||
self._directory = _cur_dir
|
||||
if ckconfig is None:
|
||||
self._config = CheckpointConfig()
|
||||
else:
|
||||
if not isinstance(ckconfig, CheckpointConfig):
|
||||
raise TypeError("ckconfig should be CheckpointConfig type.")
|
||||
self._config = ckconfig
|
||||
# get existing checkpoint files
|
||||
self._manager = CheckpointManager()
|
||||
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
|
||||
self._graph_saved = False
|
||||
self._need_flush_from_cache = True
|
||||
self.device_num = device_num
|
||||
self.device_id = device_id
|
||||
self.args = args
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Save the checkpoint at the end of step.
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
if _is_role_pserver():
|
||||
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
|
||||
cb_params = run_context.original_args()
|
||||
_make_directory(self._directory)
|
||||
# save graph (only once)
|
||||
if not self._graph_saved:
|
||||
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
|
||||
if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
|
||||
os.remove(graph_file_name)
|
||||
_save_graph(cb_params.train_network, graph_file_name)
|
||||
self._graph_saved = True
|
||||
thread_list = threading.enumerate()
|
||||
for thread in thread_list:
|
||||
if thread.getName() == "asyn_save_ckpt":
|
||||
thread.join()
|
||||
self._save_ckpt(cb_params)
|
||||
|
||||
def end(self, run_context):
|
||||
"""
|
||||
Save the last checkpoint after training finished.
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
_to_save_last_ckpt = True
|
||||
self._save_ckpt(cb_params, _to_save_last_ckpt)
|
||||
thread_list = threading.enumerate()
|
||||
for thread in thread_list:
|
||||
if thread.getName() == "asyn_save_ckpt":
|
||||
thread.join()
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
destroy_allgather_cell()
|
||||
|
||||
def _check_save_ckpt(self, cb_params, force_to_save):
|
||||
"""Check whether save checkpoint files or not."""
|
||||
if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
|
||||
if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
|
||||
or force_to_save is True:
|
||||
return True
|
||||
elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
|
||||
self._cur_time = time.time()
|
||||
if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
|
||||
self._last_time = self._cur_time
|
||||
return True
|
||||
return False
|
||||
|
||||
def _save_ckpt(self, cb_params, force_to_save=False):
|
||||
"""Save checkpoint files."""
|
||||
if cb_params.cur_step_num == self._last_triggered_step:
|
||||
return
|
||||
save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
|
||||
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
|
||||
if save_ckpt:
|
||||
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
|
||||
+ str(step_num_in_epoch) + ".ckpt"
|
||||
# update checkpoint file list.
|
||||
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
|
||||
# keep checkpoint files number equal max number.
|
||||
if self._config.keep_checkpoint_max and \
|
||||
0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
|
||||
self._manager.remove_oldest_ckpoint_file()
|
||||
elif self._config.keep_checkpoint_per_n_minutes and \
|
||||
self._config.keep_checkpoint_per_n_minutes > 0:
|
||||
self._cur_time_for_keep = time.time()
|
||||
if (self._cur_time_for_keep - self._last_time_for_keep) \
|
||||
< self._config.keep_checkpoint_per_n_minutes * 60:
|
||||
self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
|
||||
self._cur_time_for_keep)
|
||||
# generate the new checkpoint file and rename it.
|
||||
cur_file = os.path.join(self._directory, cur_ckpoint_file)
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = cb_params.cur_step_num
|
||||
if context.get_context("enable_ge"):
|
||||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
network = self._config.saved_network if self._config.saved_network is not None \
|
||||
else cb_params.train_network
|
||||
save_checkpoint(network, cur_file, self._config.integrated_save,
|
||||
self._config.async_save)
|
||||
self._latest_ckpt_file_name = cur_file
|
||||
if self.run_modelart and (self.device_num == 1 or self.device_id == 0):
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=cur_file, dst_url=os.path.join(self.args.train_url, cur_ckpoint_file))
|
||||
|
||||
def _flush_from_cache(self, cb_params):
|
||||
"""Flush cache data to host if tensor is cache enable."""
|
||||
has_cache_params = False
|
||||
params = cb_params.train_network.get_parameters()
|
||||
for param in params:
|
||||
if param.cache_enable:
|
||||
has_cache_params = True
|
||||
Tensor(param).flush_from_cache()
|
||||
if not has_cache_params:
|
||||
self._need_flush_from_cache = False
|
||||
|
||||
@property
|
||||
def latest_ckpt_file_name(self):
|
||||
"""Return the latest checkpoint path and file name."""
|
||||
return self._latest_ckpt_file_name
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss is NAN or INF terminating training.
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, per_print_times=1, rank_id=0, local_output_url="",
|
||||
device_num=1, device_id=0, args=None, run_modelart=False):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.count = 0
|
||||
self.rpn_loss_sum = 0
|
||||
self.rpn_cls_loss_sum = 0
|
||||
self.rpn_reg_loss_sum = 0
|
||||
self.rank_id = rank_id
|
||||
self.local_output_url = local_output_url
|
||||
self.device_num = device_num
|
||||
self.device_id = device_id
|
||||
self.args = args
|
||||
self.time_stamp_first = time.time()
|
||||
self.run_modelart = run_modelart
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Called after each step finished.
|
||||
Args:
|
||||
run_context (RunContext): Include some information of the model.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
rpn_loss = cb_params.net_outputs.asnumpy()
|
||||
self.count += 1
|
||||
self.rpn_loss_sum += float(rpn_loss)
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
if self.count >= 1:
|
||||
time_stamp_current = time.time()
|
||||
rpn_loss = self.rpn_loss_sum / self.count
|
||||
loss_file = open(os.path.join(self.local_output_url, lossLogName), "a+")
|
||||
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f" %
|
||||
(time_stamp_current - self.time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||
rpn_loss))
|
||||
loss_file.write("\n")
|
||||
loss_file.close()
|
||||
if self.run_modelart and (self.device_num == 1 or self.device_id == 0):
|
||||
import moxing as mox
|
||||
mox.file.copy_parallel(src_url=os.path.join(self.local_output_url, lossLogName),
|
||||
dst_url=os.path.join(self.args.train_url, lossLogName))
|
|
@ -0,0 +1,453 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ResNet."""
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.tensor import Tensor
|
||||
from scipy.stats import truncnorm
|
||||
|
||||
|
||||
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
|
||||
fan_in = in_channel * kernel_size * kernel_size
|
||||
scale = 1.0
|
||||
scale /= max(1., fan_in)
|
||||
stddev = (scale ** 0.5) / .87962566103423978
|
||||
mu, sigma = 0, stddev
|
||||
weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
|
||||
weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
|
||||
return Tensor(weight, dtype=mstype.float32)
|
||||
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
init_value = np.random.randn(*shape).astype(np.float32) * factor
|
||||
return Tensor(init_value)
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
"""calculate_gain"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
res = 0
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
res = 1
|
||||
elif nonlinearity == 'tanh':
|
||||
res = 5.0 / 3
|
||||
elif nonlinearity == 'relu':
|
||||
res = math.sqrt(2.0)
|
||||
elif nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||
# True/False are instances of int, hence check above
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
return res
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(tensor):
|
||||
"""_calculate_fan_in_and_fan_out"""
|
||||
dimensions = len(tensor)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
||||
if dimensions == 2: # Linear
|
||||
fan_in = tensor[1]
|
||||
fan_out = tensor[0]
|
||||
else:
|
||||
num_input_fmaps = tensor[1]
|
||||
num_output_fmaps = tensor[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = tensor[2] * tensor[3]
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def _calculate_correct_fan(tensor, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
fan = _calculate_correct_fan(inputs_shape, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
|
||||
|
||||
|
||||
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
|
||||
fan = _calculate_correct_fan(inputs_shape, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
|
||||
|
||||
|
||||
def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
|
||||
if use_se:
|
||||
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
|
||||
else:
|
||||
weight_shape = (out_channel, in_channel, 3, 3)
|
||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
|
||||
if use_se:
|
||||
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
|
||||
else:
|
||||
weight_shape = (out_channel, in_channel, 1, 1)
|
||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
|
||||
if use_se:
|
||||
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
|
||||
else:
|
||||
weight_shape = (out_channel, in_channel, 7, 7)
|
||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _bn(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
|
||||
|
||||
def _bn_last(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
|
||||
|
||||
def _fc(in_channel, out_channel, use_se=False):
|
||||
if use_se:
|
||||
weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
|
||||
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
|
||||
else:
|
||||
weight_shape = (out_channel, in_channel)
|
||||
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
|
||||
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
ResNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||
use_se (bool): enable SE-ResNet50 net. Default: False.
|
||||
se_block(bool): use se block in SE-ResNet50 net. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
stride=1,
|
||||
use_se=False, se_block=False):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.stride = stride
|
||||
self.use_se = use_se
|
||||
self.se_block = se_block
|
||||
channel = out_channel // self.expansion
|
||||
self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
|
||||
self.bn1 = _bn(channel)
|
||||
if self.use_se and self.stride != 1:
|
||||
self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
|
||||
nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
|
||||
else:
|
||||
self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
|
||||
self.bn2 = _bn(channel)
|
||||
|
||||
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
|
||||
self.bn3 = _bn_last(out_channel)
|
||||
if self.se_block:
|
||||
self.se_global_pool = P.ReduceMean(keep_dims=False)
|
||||
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
|
||||
self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
|
||||
self.se_sigmoid = nn.Sigmoid()
|
||||
self.se_mul = P.Mul()
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.down_sample = False
|
||||
|
||||
if stride != 1 or in_channel != out_channel:
|
||||
self.down_sample = True
|
||||
self.down_sample_layer = None
|
||||
|
||||
if self.down_sample:
|
||||
if self.use_se:
|
||||
if stride == 1:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
|
||||
stride, use_se=self.use_se), _bn(out_channel)])
|
||||
else:
|
||||
self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
|
||||
_conv1x1(in_channel, out_channel, 1,
|
||||
use_se=self.use_se), _bn(out_channel)])
|
||||
else:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
|
||||
use_se=self.use_se), _bn(out_channel)])
|
||||
|
||||
def construct(self, x):
|
||||
"""ResidualBlock construct"""
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
if self.use_se and self.stride != 1:
|
||||
out = self.e2(out)
|
||||
else:
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
if self.se_block:
|
||||
out_se = out
|
||||
out = self.se_global_pool(out, (2, 3))
|
||||
out = self.se_dense_0(out)
|
||||
out = self.relu(out)
|
||||
out = self.se_dense_1(out)
|
||||
out = self.se_sigmoid(out)
|
||||
out = F.reshape(out, F.shape(out) + (1, 1))
|
||||
out = self.se_mul(out, out_se)
|
||||
|
||||
if self.down_sample:
|
||||
identity = self.down_sample_layer(identity)
|
||||
|
||||
out = out + identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Cell):
|
||||
"""
|
||||
ResNet architecture.
|
||||
|
||||
Args:
|
||||
block (Cell): Block for network.
|
||||
layer_nums (list): Numbers of block in different layers.
|
||||
in_channels (list): Input channel in each layer.
|
||||
out_channels (list): Output channel in each layer.
|
||||
strides (list): Stride size in each layer.
|
||||
num_classes (int): The number of classes that the training images are belonging to.
|
||||
use_se (bool): enable SE-ResNet50 net. Default: False.
|
||||
se_block(bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
block,
|
||||
layer_nums,
|
||||
in_channels,
|
||||
out_channels,
|
||||
strides,
|
||||
num_classes,
|
||||
use_se=False):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
|
||||
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
|
||||
self.use_se = use_se
|
||||
self.se_block = False
|
||||
if self.use_se:
|
||||
self.se_block = True
|
||||
|
||||
if self.use_se:
|
||||
self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
|
||||
self.bn1_0 = _bn(32)
|
||||
self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
|
||||
self.bn1_1 = _bn(32)
|
||||
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
|
||||
else:
|
||||
self.conv1 = _conv7x7(3, 64, stride=2)
|
||||
self.bn1 = _bn(64)
|
||||
self.relu = P.ReLU()
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
self.layer1 = self._make_layer(block,
|
||||
layer_nums[0],
|
||||
in_channel=in_channels[0],
|
||||
out_channel=out_channels[0],
|
||||
stride=strides[0],
|
||||
use_se=self.use_se)
|
||||
self.layer2 = self._make_layer(block,
|
||||
layer_nums[1],
|
||||
in_channel=in_channels[1],
|
||||
out_channel=out_channels[1],
|
||||
stride=strides[1],
|
||||
use_se=self.use_se)
|
||||
self.layer3 = self._make_layer(block,
|
||||
layer_nums[2],
|
||||
in_channel=in_channels[2],
|
||||
out_channel=out_channels[2],
|
||||
stride=strides[2],
|
||||
use_se=self.use_se,
|
||||
se_block=self.se_block)
|
||||
self.layer4 = self._make_layer(block,
|
||||
layer_nums[3],
|
||||
in_channel=in_channels[3],
|
||||
out_channel=out_channels[3],
|
||||
stride=strides[3],
|
||||
use_se=self.use_se,
|
||||
se_block=self.se_block)
|
||||
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
|
||||
"""
|
||||
Make stage network of ResNet.
|
||||
|
||||
Args:
|
||||
block (Cell): Resnet block.
|
||||
layer_num (int): Layer number.
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer.
|
||||
se_block(bool): use se block in SE-ResNet50 net. Default: False.
|
||||
Returns:
|
||||
SequentialCell, the output layer.
|
||||
|
||||
"""
|
||||
layers = []
|
||||
|
||||
resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
|
||||
layers.append(resnet_block)
|
||||
if se_block:
|
||||
for _ in range(1, layer_num - 1):
|
||||
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
|
||||
layers.append(resnet_block)
|
||||
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
|
||||
layers.append(resnet_block)
|
||||
else:
|
||||
for _ in range(1, layer_num):
|
||||
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
|
||||
layers.append(resnet_block)
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
"""ResNet construct"""
|
||||
if self.use_se:
|
||||
x = self.conv1_0(x)
|
||||
x = self.bn1_0(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv1_1(x)
|
||||
x = self.bn1_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv1_2(x)
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
c1 = self.maxpool(x)
|
||||
|
||||
c2 = self.layer1(c1)
|
||||
c3 = self.layer2(c2)
|
||||
c4 = self.layer3(c3)
|
||||
c5 = self.layer4(c4)
|
||||
|
||||
feature1 = c5
|
||||
out = self.mean(c5, (2, 3))
|
||||
|
||||
out = self.flatten(out)
|
||||
|
||||
feature2 = out
|
||||
out = self.end_point(out)
|
||||
|
||||
return out, feature1, feature2
|
||||
|
||||
|
||||
def resnet50(class_num=10):
|
||||
"""
|
||||
Get ResNet50 neural network.
|
||||
|
||||
Args:
|
||||
class_num (int): Class number.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of ResNet50 neural network.
|
||||
|
||||
"""
|
||||
return ResNet(ResidualBlock,
|
||||
[3, 4, 6, 3],
|
||||
[64, 256, 512, 1024],
|
||||
[256, 512, 1024, 2048],
|
||||
[1, 2, 2, 2],
|
||||
class_num)
|
||||
|
||||
|
||||
def se_resnet50(class_num=1001):
|
||||
"""
|
||||
Get SE-ResNet50 neural network.
|
||||
|
||||
Args:
|
||||
class_num (int): Class number.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of SE-ResNet50 neural network.
|
||||
|
||||
"""
|
||||
return ResNet(ResidualBlock,
|
||||
[3, 4, 6, 3],
|
||||
[64, 256, 512, 1024],
|
||||
[256, 512, 1024, 2048],
|
||||
[1, 2, 2, 2],
|
||||
class_num,
|
||||
use_se=True)
|
||||
|
||||
|
||||
def resnet101(class_num=1001):
|
||||
"""
|
||||
Get ResNet101 neural network.
|
||||
|
||||
Args:
|
||||
class_num (int): Class number.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of ResNet101 neural network.
|
||||
|
||||
"""
|
||||
return ResNet(ResidualBlock,
|
||||
[3, 4, 23, 3],
|
||||
[64, 256, 512, 1024],
|
||||
[256, 512, 1024, 2048],
|
||||
[1, 2, 2, 2],
|
||||
class_num)
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ntsnet train."""
|
||||
import ast
|
||||
import math
|
||||
import os
|
||||
import argparse
|
||||
from mindspore.train.callback import CheckpointConfig, TimeMonitor
|
||||
from mindspore import context, nn, Tensor, set_seed, Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from src.config import config
|
||||
from src.dataset import create_dataset_train
|
||||
from src.lr_generator import get_lr
|
||||
from src.network import NTS_NET, WithLossCell, LossCallBack, ModelCheckpoint
|
||||
|
||||
parser = argparse.ArgumentParser(description='ntsnet train running')
|
||||
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default is false.")
|
||||
parser.add_argument('--data_url', default=None,
|
||||
help='Directory contains resnet50.ckpt and CUB_200_2011 dataset.')
|
||||
parser.add_argument('--train_url', default=None, help='Directory of training output.')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
||||
args = parser.parse_args()
|
||||
run_modelart = args.run_modelart
|
||||
if run_modelart:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
local_input_url = '/cache/data' + str(device_id)
|
||||
local_output_url = '/cache/ckpt' + str(device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||
save_graphs=False)
|
||||
context.set_context(device_id=device_id)
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
global_rank=device_id,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
rank = get_rank()
|
||||
else:
|
||||
rank = 0
|
||||
import moxing as mox
|
||||
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_input_url)
|
||||
elif args.run_distribute:
|
||||
device_id = args.device_id
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
context.set_context(device_id=device_id)
|
||||
init()
|
||||
device_num = get_group_size()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
local_input_url = args.data_url
|
||||
local_output_url = args.train_url
|
||||
rank = get_rank()
|
||||
else:
|
||||
device_id = args.device_id
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
context.set_context(device_id=device_id)
|
||||
rank = 0
|
||||
device_num = 1
|
||||
local_input_url = args.data_url
|
||||
local_output_url = args.train_url
|
||||
|
||||
learning_rate = config.learning_rate
|
||||
momentum = config.momentum
|
||||
weight_decay = config.weight_decay
|
||||
batch_size = config.batch_size
|
||||
num_train_images = config.num_train_images
|
||||
num_epochs = config.num_epochs
|
||||
steps_per_epoch = math.ceil(num_train_images / batch_size)
|
||||
lr = Tensor(get_lr(global_step=0,
|
||||
lr_init=0,
|
||||
lr_max=learning_rate,
|
||||
warmup_epochs=4,
|
||||
total_epochs=num_epochs,
|
||||
steps_per_epoch=steps_per_epoch))
|
||||
|
||||
if __name__ == '__main__':
|
||||
set_seed(1)
|
||||
resnet50Path = os.path.join(local_input_url, "resnet50.ckpt")
|
||||
ntsnet = NTS_NET(topK=6, resnet50Path=resnet50Path)
|
||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
|
||||
optimizer = nn.SGD(ntsnet.trainable_params(), learning_rate=lr, momentum=momentum, weight_decay=weight_decay)
|
||||
loss_net = WithLossCell(ntsnet, loss_fn)
|
||||
oneStepNTSNet = nn.TrainOneStepCell(loss_net, optimizer)
|
||||
|
||||
train_data_set = create_dataset_train(train_path=os.path.join(local_input_url, "CUB_200_2011/train"),
|
||||
batch_size=batch_size)
|
||||
dataset_size = train_data_set.get_batch_size()
|
||||
time_cb = TimeMonitor(data_size=dataset_size)
|
||||
|
||||
loss_cb = LossCallBack(rank_id=rank, local_output_url=local_output_url, device_num=device_num, device_id=device_id,
|
||||
args=args, run_modelart=run_modelart)
|
||||
cb = [time_cb, loss_cb]
|
||||
|
||||
if config.save_checkpoint and rank == 0:
|
||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * steps_per_epoch,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_checkpoint_path = os.path.join(local_output_url, "ckpt_" + str(rank) + "/")
|
||||
|
||||
ckpoint_cb = ModelCheckpoint(prefix=config.prefix, directory=save_checkpoint_path, ckconfig=ckptconfig,
|
||||
device_num=device_num, device_id=device_id, args=args, run_modelart=run_modelart)
|
||||
cb += [ckpoint_cb]
|
||||
|
||||
model = Model(oneStepNTSNet, amp_level="O3", keep_batchnorm_fp32=False)
|
||||
model.train(config.num_epochs, train_data_set, callbacks=cb, dataset_sink_mode=True)
|
Loading…
Reference in New Issue