adding GPU mode and CPU mode

huangbo77 2021-04-16 16:04:16 +08:00
parent 17e42347fc
commit 2732d7333a
14 changed files with 656 additions and 392 deletions


@ -13,7 +13,7 @@
# [Face Recognition For Tracking Description](#contents) # [Face Recognition For Tracking Description](#contents)
This is a face recognition for tracking network based on Resnet, with support for training and evaluation on Ascend910. This is a face recognition for tracking network based on Resnet, with support for training and evaluation on Ascend910, GPU and CPU.
ResNet (residual neural network) was proposed by Kaiming He and three co-authors at Microsoft Research. Using residual units, they trained a 152-layer network that won ILSVRC 2015 with a top-5 error rate of 3.57%, while using fewer parameters than VGGNet. Traditional convolutional or fully connected networks lose some information as it passes through the layers, and they can suffer from vanishing or exploding gradients, which makes very deep networks fail to train. ResNet mitigates this problem: by passing the input directly to the output through a shortcut, it preserves the integrity of the information.
The whole network only needs to learn the difference (the residual) between input and output, which simplifies the learning objective. This structure speeds up training and improves model accuracy. ResNet is also widely reused, and can even be combined directly with the Inception network.
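The residual idea described above (each block learns only the difference between its input and output) can be illustrated with a minimal pure-Python sketch. The helper names `linear` and `residual_block` and the toy two-layer transform are illustrative assumptions, not code from this repository.

```python
def linear(x, w):
    """Multiply a vector x by a weight matrix w given as a list of rows (in x out)."""
    return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*w)]


def residual_block(x, w1, w2):
    """Toy residual unit: output = x + F(x), where F is linear -> ReLU -> linear."""
    hidden = [max(h, 0.0) for h in linear(x, w1)]  # first layer followed by ReLU
    residual = linear(hidden, w2)                  # F(x): the part the block must learn
    return [xi + ri for xi, ri in zip(x, residual)]  # identity shortcut adds x back


x = [1.0, 2.0, 3.0]
zero = [[0.0] * 3 for _ in range(3)]
# With all-zero weights F(x) is zero, so the block reduces to the identity mapping.
out = residual_block(x, zero, zero)
```

Because the shortcut carries `x` through unchanged, a block whose weights contribute nothing still passes its input along intact, which is the property the paragraph credits for making deep networks trainable.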
@ -55,7 +55,7 @@ The directory structure is as follows:
# [Environment Requirements](#contents) # [Environment Requirements](#contents)
- Hardware(Ascend) - Hardware(Ascend/GPU/CPU)
- Prepare hardware environment with Ascend processor. - Prepare hardware environment with Ascend processor.
- Framework - Framework
- [MindSpore](https://www.mindspore.cn/install/en) - [MindSpore](https://www.mindspore.cn/install/en)
@ -77,19 +77,25 @@ The entire code structure is as following:
├─ run_standalone_train.sh # launch standalone training(1p) in ascend ├─ run_standalone_train.sh # launch standalone training(1p) in ascend
├─ run_distribute_train.sh # launch distributed training(8p) in ascend ├─ run_distribute_train.sh # launch distributed training(8p) in ascend
├─ run_eval.sh # launch evaluating in ascend ├─ run_eval.sh # launch evaluating in ascend
└─ run_export.sh # launch exporting air model ├─ run_export.sh # launch exporting air/mindir model
├─ run_standalone_train_gpu.sh # launch standalone training(1p) in gpu
├─ run_distribute_train_gpu.sh # launch distributed training(8p) in gpu
├─ run_eval_gpu.sh # launch evaluating in gpu
├─ run_export_gpu.sh # launch exporting mindir model in gpu
├─ run_train_cpu.sh # launch standalone training in cpu
├─ run_eval_cpu.sh # launch evaluating in cpu
└─ run_export_cpu.sh # launch exporting mindir model in cpu
├─ src ├─ src
├─ config.py # parameter configuration ├─ config.py # parameter configuration
├─ dataset.py # dataset loading and preprocessing for training ├─ dataset.py # dataset loading and preprocessing for training
├─ reid.py # network backbone ├─ reid.py # network backbone
├─ reid_for_export.py # network backbone for export
├─ log.py # log function ├─ log.py # log function
├─ loss.py # loss function ├─ loss.py # loss function
├─ lr_generator.py # generate learning rate ├─ lr_generator.py # generate learning rate
└─ me_init.py # network initialization └─ me_init.py # network initialization
├─ train.py # training scripts ├─ train.py # training scripts
├─ eval.py # evaluation scripts ├─ eval.py # evaluation scripts
└─ export.py # export air model └─ export.py # export air/mindir model
``` ```
## [Running Example](#contents) ## [Running Example](#contents)
@ -99,18 +105,50 @@ The entire code structure is as following:
- Stand alone mode - Stand alone mode
```bash ```bash
Ascend:
cd ./scripts cd ./scripts
sh run_standalone_train.sh [DATA_DIR] [USE_DEVICE_ID] sh run_standalone_train.sh [DATA_DIR] [USE_DEVICE_ID]
``` ```
```bash
GPU:
cd ./scripts
sh run_standalone_train_gpu.sh [DATA_DIR]
```
```bash
CPU:
cd ./scripts
sh run_train_cpu.sh [DATA_DIR]
```
or (fine-tune) or (fine-tune)
```bash ```bash
Ascend:
cd ./scripts cd ./scripts
sh run_standalone_train.sh [DATA_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] sh run_standalone_train.sh [DATA_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
``` ```
for example: ```bash
GPU:
cd ./scripts
sh run_standalone_train_gpu.sh [DATA_DIR] [PRETRAINED_BACKBONE]
```
```bash
CPU:
cd ./scripts
sh run_train_cpu.sh [DATA_DIR] [PRETRAINED_BACKBONE]
```
for example, on Ascend:
```bash ```bash
cd ./scripts cd ./scripts
@ -120,17 +158,35 @@ The entire code structure is as following:
- Distribute mode (recommended) - Distribute mode (recommended)
```bash ```bash
Ascend:
cd ./scripts cd ./scripts
sh run_distribute_train.sh [DATA_DIR] [RANK_TABLE] sh run_distribute_train.sh [DATA_DIR] [RANK_TABLE]
``` ```
```bash
GPU:
cd ./scripts
sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
```
or (fine-tune) or (fine-tune)
```bash ```bash
Ascend:
cd ./scripts cd ./scripts
sh run_distribute_train.sh [DATA_DIR] [RANK_TABLE] [PRETRAINED_BACKBONE] sh run_distribute_train.sh [DATA_DIR] [RANK_TABLE] [PRETRAINED_BACKBONE]
``` ```
```bash
GPU:
cd ./scripts
sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRE_TRAINED]
```
for example: for example:
```bash ```bash
@ -156,11 +212,27 @@ epoch[179], iter[14930], loss:1.694281, 13417.38 imgs/sec, lr=0.0250000003725290
### Evaluation ### Evaluation
```bash ```bash
Ascend:
cd ./scripts cd ./scripts
sh run_eval.sh [EVAL_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] sh run_eval.sh [EVAL_DIR] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
``` ```
for example: ```bash
GPU:
cd ./scripts
sh run_eval_gpu.sh [EVAL_DIR] [PRETRAINED_BACKBONE]
```
```bash
CPU:
cd ./scripts
sh run_eval_cpu.sh [EVAL_DIR] [PRETRAINED_BACKBONE]
```
for example, on Ascend:
```bash ```bash
cd ./scripts cd ./scripts
@ -184,44 +256,62 @@ You will get the result as following in "./scripts/device0/eval.log" or txt file
If you want to infer the network on Ascend 310, you should convert the model to AIR: If you want to infer the network on Ascend 310, you should convert the model to AIR:
```bash ```bash
Ascend:
cd ./scripts cd ./scripts
sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
``` ```
Or, to convert your model to a MINDIR file on GPU or CPU:
```bash
GPU:
cd ./scripts
sh run_export_gpu.sh [PRETRAINED_BACKBONE] [BATCH_SIZE] [FILE_NAME](optional)
```
```bash
CPU:
cd ./scripts
sh run_export_cpu.sh [PRETRAINED_BACKBONE] [BATCH_SIZE] [FILE_NAME](optional)
```
# [Model Description](#contents) # [Model Description](#contents)
## [Performance](#contents) ## [Performance](#contents)
### Training Performance ### Training Performance
| Parameters | Face Recognition For Tracking | | Parameters | Ascend |GPU |CPU |
| -------------------------- | ----------------------------------------------------------- | | -------------------------- | ----------------------------------------------------------- | ----------------------------------------------------------- | ----------------------------------------------------------- |
| Model Version | V1 | | Model Version | V1 | V1 | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | | Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G; OS Euler2.8 |Tesla V100-PCIE |Intel(R) Xeon(R) CPU E5-2690 v4 |
| uploaded Date | 09/30/2020 (month/day/year) | | uploaded Date | 09/30/2020 (month/day/year) |04/17/2021 (month/day/year) |04/17/2021 (month/day/year) |
| MindSpore Version | 1.0.0 | | MindSpore Version | 1.0.0 | 1.2.0 |1.2.0 |
| Dataset | 10K images | | Dataset | 10K images | 10K images | 10K images |
| Training Parameters | epoch=180, batch_size=16, momentum=0.9 | | Training Parameters | epoch=180, batch_size=16, momentum=0.9 | epoch=40, batch_size=128(1p); 16(8p), momentum=0.9 | epoch=40, batch_size=128, momentum=0.9 |
| Optimizer | Momentum | | Optimizer | SGD | SGD | SGD |
| Loss Function | Softmax Cross Entropy | | Loss Function | Softmax Cross Entropy | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | | outputs | probability | probability |probability |
| Speed | 1pc: 8~10 ms/step; 8pcs: 9~11 ms/step | | Speed | 1pc: 8-10 ms/step; 8pcs: 9-11 ms/step | 1pc: 30 ms/step; 8pcs: 20 ms/step | 1pc: 2.5 s/step |
| Total time | 1pc: 1 hours; 8pcs: 0.1 hours | | Total time | 1pc: 1 hour; 8pcs: 0.1 hours | 1pc: 2 minutes; 8pcs: 1.5 minutes |1pc: 2 hours |
| Checkpoint for Fine tuning | 17M (.ckpt file) | | Checkpoint for Fine tuning | 17M (.ckpt file) | 17M (.ckpt file) | 17M (.ckpt file) |
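The speed row can be converted into an approximate throughput for comparison with the `imgs/sec` figures in the training log. The sketch below assumes that the global batch is the per-device batch multiplied by the device count and treats the table's step times as rough values; `images_per_second` is a hypothetical helper, not part of the repository.

```python
def images_per_second(per_device_batch, num_devices, seconds_per_step):
    """Approximate throughput: images processed per step divided by step time."""
    return per_device_batch * num_devices / seconds_per_step


# Step times and batch sizes are taken from the GPU and CPU columns above.
gpu_1p = images_per_second(128, 1, 0.030)  # 1 GPU, batch 128, ~30 ms/step
gpu_8p = images_per_second(16, 8, 0.020)   # 8 GPUs, batch 16 each, ~20 ms/step
cpu_1p = images_per_second(128, 1, 2.5)    # CPU, batch 128, ~2.5 s/step
```

These are back-of-the-envelope numbers only; measured throughput depends on data loading and hardware details that the table does not capture.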
### Evaluation Performance ### Evaluation Performance
| Parameters |Face Recognition For Tracking| | Parameters |Ascend |GPU |CPU |
| ------------------- | --------------------------- | | ------------------- | --------------------------- | --------------------------- | --------------------------- |
| Model Version | V1 | | Model Version |V1 |V1 |V1 |
| Resource | Ascend 910; OS Euler2.8 | | Resource | Ascend 910; OS Euler2.8 |Tesla V100-PCIE |Intel(R) Xeon(R) CPU E5-2690 v4 |
| Uploaded Date | 09/30/2020 (month/day/year) | | Uploaded Date | 09/30/2020 (month/day/year) | 04/17/2021 (month/day/year) | 04/17/2021 (month/day/year) |
| MindSpore Version | 1.0.0 | | MindSpore Version | 1.0.0 | 1.2.0 |1.2.0 |
| Dataset | 2K images | | Dataset | 2K images | 2K images | 2K images |
| batch_size | 128 | | batch_size | 128 | 128 |128 |
| outputs | recall | | outputs | recall | recall |recall |
| Recall(8pcs) | 0.62(FAR=0.1) | | Recall | 0.62(FAR=0.1) | 0.62(FAR=0.1) | 0.62(FAR=0.1) |
| Model for inference | 17M (.ckpt file) | | Model for inference | 17M (.ckpt file) | 17M (.ckpt file) | 17M (.ckpt file) |
# [ModelZoo Homepage](#contents) # [ModelZoo Homepage](#contents)


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -29,8 +29,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.reid import SphereNet from src.reid import SphereNet
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
def inclass_likehood(ims_info, types='cos'): def inclass_likehood(ims_info, types='cos'):
@ -135,7 +133,10 @@ def main(args):
else: else:
print('-----------------------load model failed -----------------------') print('-----------------------load model failed -----------------------')
network.add_flags_recursive(fp16=True) if args.device_target == 'CPU':
network.add_flags_recursive(fp32=True)
else:
network.add_flags_recursive(fp16=True)
network.set_train(False) network.set_train(False)
root_path = args.eval_dir root_path = args.eval_dir
@ -178,8 +179,15 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='reid test') parser = argparse.ArgumentParser(description='reid test')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--eval_dir', type=str, default='', help='eval image dir, e.g. /home/test') parser.add_argument('--eval_dir', type=str, default='', help='eval image dir, e.g. /home/test')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device_target')
arg = parser.parse_args() arg = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=arg.device_target, save_graphs=False)
if arg.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
print(arg) print(arg)
main(arg) main(arg)
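The device handling introduced in this file can be summarized as a small, framework-free sketch: only Ascend reads `DEVICE_ID`, and only CPU uses fp32 flags. The helper names below are hypothetical, and the fallback to device 0 when `DEVICE_ID` is unset is an assumption (the script itself calls `int(os.getenv('DEVICE_ID'))`, which fails if the variable is missing).

```python
import os


def build_context_kwargs(device_target, env=None):
    """Return keyword arguments for a context.set_context-style call."""
    env = os.environ if env is None else env
    kwargs = {"device_target": device_target, "save_graphs": False}
    if device_target == "Ascend":
        # Assumption: fall back to device 0 when DEVICE_ID is not set.
        kwargs["device_id"] = int(env.get("DEVICE_ID", "0"))
    return kwargs


def precision_flag(device_target):
    """Mirror the branch in main(): fp32 on CPU, fp16 on Ascend and GPU."""
    return "fp32" if device_target == "CPU" else "fp16"


ascend_kwargs = build_context_kwargs("Ascend", {"DEVICE_ID": "3"})
cpu_kwargs = build_context_kwargs("CPU", {})
```

Keeping the decision in plain functions like these makes it testable without a MindSpore installation; the real script passes the equivalent values straight to `context.set_context`.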


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Convert ckpt to air.""" """Convert ckpt to air/mindir."""
import os import os
import argparse import argparse
import numpy as np import numpy as np
@ -21,14 +21,11 @@ from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.reid_for_export import SphereNet from src.reid import SphereNet_float32
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
def main(args): def main(args):
network = SphereNet(num_layers=12, feature_dim=128, shape=(96, 64)) network = SphereNet_float32(num_layers=12, feature_dim=128, shape=(96, 64))
ckpt_path = args.pretrained ckpt_path = args.pretrained
if os.path.isfile(ckpt_path): if os.path.isfile(ckpt_path):
param_dict = load_checkpoint(ckpt_path) param_dict = load_checkpoint(ckpt_path)
@ -45,23 +42,35 @@ def main(args):
else: else:
print('-----------------------load model failed -----------------------') print('-----------------------load model failed -----------------------')
network.add_flags_recursive(fp16=True) if args.device_target == 'CPU':
network.add_flags_recursive(fp32=True)
else:
network.add_flags_recursive(fp16=True)
network.set_train(False) network.set_train(False)
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 96, 64)).astype(np.float32) input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 96, 64)).astype(np.float32)
tensor_input_data = Tensor(input_data) tensor_input_data = Tensor(input_data)
export(network, tensor_input_data, file_name=ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), export(network, tensor_input_data, file_name=args.file_name, file_format=args.file_format)
file_format='AIR')
print('-----------------------export model success-----------------------') print('-----------------------export model success-----------------------')
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Convert ckpt to air') parser = argparse.ArgumentParser(description='Convert ckpt to air/mindir')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--batch_size', type=int, default=8, help='batch size') parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device_target')
parser.add_argument('--file_name', type=str, default='FaceRecognitionForTracking', help='output file name')
parser.add_argument('--file_format', type=str, choices=['AIR', 'ONNX', 'MINDIR'], default='AIR', help='file format')
arg = parser.parse_args() arg = parser.parse_args()
if arg.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
context.set_context(mode=context.GRAPH_MODE, device_target=arg.device_target)
main(arg) main(arg)
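This change replaces the file name derived from the checkpoint path with explicit `--file_name` and `--file_format` arguments. The sketch below reproduces the old naming rule and adds a hypothetical `suggested_format` helper that follows the README and the new wrapper scripts (AIR for Ascend, MINDIR for GPU and CPU); note that `export.py` itself defaults to AIR regardless of target.

```python
def legacy_air_name(ckpt_path, batch_size):
    """Old behaviour: derive the output name from the checkpoint path."""
    return ckpt_path.replace('.ckpt', '_' + str(batch_size) + 'b.air')


def suggested_format(device_target):
    """Hypothetical helper: format the wrapper scripts use for each target."""
    return 'AIR' if device_target == 'Ascend' else 'MINDIR'


old_name = legacy_air_name('/models/reid.ckpt', 8)
gpu_format = suggested_format('GPU')
```

One consequence of the new interface is that exporting twice with different batch sizes overwrites the same file unless a distinct `--file_name` is passed each time.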


@ -0,0 +1,59 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 3 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRE_TRAINED](optional)"
exit 1
fi
if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi
export DEVICE_NUM=$1
export RANK_SIZE=$1
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ]
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
export CUDA_VISIBLE_DEVICES="$2"
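# Optional pretrained checkpoint: passed through only when a 4th argument is given.
PRETRAINED_ARG=""
if [ -n "$4" ]
then
PRETRAINED_ARG="--pretrained=$4"
fi
if [ $1 -gt 1 ]
then
# Multi-GPU: launch one process per device through mpirun.
mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
--data_dir=$3 \
--is_distributed=1 \
--device_target='GPU' \
$PRETRAINED_ARG
else
# Single GPU: run train.py directly.
python3 ${BASEPATH}/../train.py \
--data_dir=$3 \
--is_distributed=0 \
--device_target='GPU' \
$PRETRAINED_ARG
fi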
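The launch logic of this script (validate `DEVICE_NUM`, export the visible devices, choose `mpirun` for more than one process) can be sketched in Python as a pure function that builds the command without running it. The function name is hypothetical; the check that enough devices are listed and the forwarding of `--pretrained` (as the standalone scripts do) are assumptions.

```python
def build_gpu_train_command(device_num, visible_devices, data_dir, pretrained=None):
    """Return (env, argv) for a GPU training launch; nothing is executed here."""
    if not 1 <= device_num <= 8:
        raise ValueError(f"DEVICE_NUM={device_num} is not in (1-8)")
    devices = [d.strip() for d in visible_devices.split(",") if d.strip()]
    if len(devices) < device_num:
        # Assumption: this extra check is not part of the shell script.
        raise ValueError("fewer visible devices than DEVICE_NUM")
    argv = [
        "python3", "train.py",
        f"--data_dir={data_dir}",
        f"--is_distributed={1 if device_num > 1 else 0}",
        "--device_target=GPU",
    ]
    if pretrained:
        argv.append(f"--pretrained={pretrained}")
    if device_num > 1:
        # More than one process: wrap the command in mpirun, as the script does.
        argv = ["mpirun", "-n", str(device_num), "--allow-run-as-root"] + argv
    env = {"CUDA_VISIBLE_DEVICES": ",".join(devices), "RANK_SIZE": str(device_num)}
    return env, argv


env, argv = build_gpu_train_command(2, "0,1", "/data/train")
```

The returned `argv` list could be handed to `subprocess.run` with `env` merged into the current environment; separating construction from execution keeps the validation easy to test.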


@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: sh run_eval_cpu.sh [EVALDATA_PATH] [PRE_TRAINED]"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
python3 ${BASEPATH}/../eval.py \
--eval_dir=$1 \
--device_target='CPU' \
--pretrained=$2


@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: sh run_eval_gpu.sh [EVALDATA_PATH] [PRE_TRAINED]"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
python3 ${BASEPATH}/../eval.py \
--eval_dir=$1 \
--device_target='GPU' \
--pretrained=$2


@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: sh run_export_cpu.sh [PRE_TRAINED] [BATCH_SIZE] [FILE_NAME](optional)"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
cd ..
if [ $3 ] #file name
then
python3 ${BASEPATH}/../export.py \
--pretrained=$1 \
--device_target='CPU' \
--batch_size=$2 \
--file_format=MINDIR \
--file_name=$3
else
python3 ${BASEPATH}/../export.py \
--pretrained=$1 \
--device_target='CPU' \
--batch_size=$2 \
--file_format=MINDIR
fi


@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: sh run_export_gpu.sh [PRE_TRAINED] [BATCH_SIZE] [FILE_NAME](optional)"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
cd ..
if [ $3 ] #file name
then
python3 ${BASEPATH}/../export.py \
--pretrained=$1 \
--device_target='GPU' \
--batch_size=$2 \
--file_format=MINDIR \
--file_name=$3
else
python3 ${BASEPATH}/../export.py \
--pretrained=$1 \
--device_target='GPU' \
--batch_size=$2 \
--file_format=MINDIR
fi


@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 1 ]
then
echo "Usage: sh run_standalone_train_gpu.sh [DATASET_PATH] [PRE_TRAINED] (optional)"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ]
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
if [ $2 ] #pretrained ckpt
then
python3 ${BASEPATH}/../train.py \
--data_dir=$1 \
--device_target='GPU' \
--pretrained=$2
else
python3 ${BASEPATH}/../train.py \
--data_dir=$1 \
--device_target='GPU'
fi


@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 1 ]
then
echo "Usage: sh run_train_cpu.sh [DATASET_PATH] [PRE_TRAINED](optional)"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ]
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
if [ $2 ] #pretrained ckpt
then
python3 ${BASEPATH}/../train.py \
--data_dir=$1 \
--device_target='CPU' \
--pretrained=$2
else
python3 ${BASEPATH}/../train.py \
--data_dir=$1 \
--device_target='CPU'
fi


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,8 +15,8 @@
"""Network config setting, will be used in train.py and eval.py""" """Network config setting, will be used in train.py and eval.py"""
from easydict import EasyDict as edict from easydict import EasyDict as edict
reid_1p_cfg = edict({ reid_1p_cfg_ascend = edict({
'task': 'REID_1p', 'task': 'REID_1p_ascend',
# dataset related # dataset related
'per_batch_size': 128, 'per_batch_size': 128,
@ -52,8 +52,8 @@ reid_1p_cfg = edict({
}) })
reid_8p_cfg = edict({ reid_8p_cfg_ascend = edict({
'task': 'REID_8p', 'task': 'REID_8p_ascend',
# dataset related # dataset related
'per_batch_size': 16, 'per_batch_size': 16,
@ -87,3 +87,76 @@ reid_8p_cfg = edict({
'ckpt_path': '../../output', 'ckpt_path': '../../output',
'ckpt_interval': 200, 'ckpt_interval': 200,
}) })
reid_1p_cfg = edict({
'task': 'REID_1p',
# dataset related
'per_batch_size': 128,
# network structure related
'fp16': 1,
'loss_scale': 2048.0,
'input_size': (96, 64),
'net_depth': 12,
'embedding_size': 128,
# optimizer related
'lr': 0.1,
'lr_scale': 1,
'lr_gamma': 1,
'lr_epochs': '30,60,120,150',
'epoch_size': 30,
'warmup_epochs': 0,
'steps_per_epoch': 0,
'max_epoch': 40,
'weight_decay': 0.0005,
'momentum': 0.9,
# distributed parameter
'is_distributed': 0,
'local_rank': 0,
'world_size': 1,
# logging related
'log_interval': 10,
'ckpt_path': '../output',
'ckpt_interval': 200,
})
reid_8p_cfg_gpu = edict({
'task': 'REID_8p_gpu',
# dataset related
'per_batch_size': 16,
# network structure related
'fp16': 1,
'loss_scale': 2048.0,
'input_size': (96, 64),
'net_depth': 12,
'embedding_size': 128,
# optimizer related
'lr': 0.1,
'lr_scale': 1,
'lr_gamma': 1,
'lr_epochs': '30,60,120,150',
'epoch_size': 30,
'warmup_epochs': 0,
'steps_per_epoch': 0,
'max_epoch': 40,
'weight_decay': 0.0005,
'momentum': 0.9,
# distributed parameter
'is_distributed': 1,
'local_rank': 0,
'world_size': 8,
# logging related
'log_interval': 10,
'ckpt_path': '../output',
'ckpt_interval': 200,
})
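The two new configs differ mainly in `per_batch_size` and `world_size`, and their product gives the same global batch. The sketch below uses plain dictionaries with only the relevant keys; how `train.py` actually chooses between the configs is not shown in this section, so `select_config` is a hypothetical stand-in.

```python
# Only the fields needed for the arithmetic are copied from the configs above.
CONFIGS = {
    "REID_1p": {"per_batch_size": 128, "world_size": 1, "max_epoch": 40},
    "REID_8p_gpu": {"per_batch_size": 16, "world_size": 8, "max_epoch": 40},
}


def select_config(is_distributed):
    """Hypothetical selection rule: 8-device config when distributed, else 1-device."""
    return CONFIGS["REID_8p_gpu"] if is_distributed else CONFIGS["REID_1p"]


def global_batch_size(cfg):
    """Images consumed per optimizer step across all devices."""
    return cfg["per_batch_size"] * cfg["world_size"]


single = global_batch_size(select_config(False))
distributed = global_batch_size(select_config(True))
```

Since both settings process 128 images per step, the shared learning rate of 0.1 applies to the same effective batch in either mode.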


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -23,18 +23,11 @@ from mindspore.nn import Dense, Cell
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
from mindspore import context
from src import me_init from src import me_init
class Cut(nn.Cell):
def construct(self, x):
return x
def bn_with_initialize(out_channels): def bn_with_initialize(out_channels):
bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5).add_flags_recursive(fp32=True) bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5).add_flags_recursive(fp32=True)
return bn return bn
@ -77,6 +70,7 @@ class BaseBlock(Cell):
self.cast = P.Cast() self.cast = P.Cast()
self.add = Add() self.add = Add()
self.device_target = context.get_context('device_target')
def construct(self, x): def construct(self, x):
'''Construct function.''' '''Construct function.'''
@ -88,8 +82,9 @@ class BaseBlock(Cell):
out = self.bn2(out) out = self.bn2(out)
out = self.relu2(out) out = self.relu2(out)
# hand cast # hand cast
identity = self.cast(identity, mstype.float16) if self.device_target != 'CPU':
out = self.cast(out, mstype.float16) identity = self.cast(identity, mstype.float16)
out = self.cast(out, mstype.float16)
out = self.add(out, identity) out = self.add(out, identity)
return out return out
@ -143,7 +138,6 @@ class SphereNet(Cell):
raise ValueError('sphere' + str(num_layers) + " IS NOT SUPPORTED! (sphere20 or sphere64)") raise ValueError('sphere' + str(num_layers) + " IS NOT SUPPORTED! (sphere20 or sphere64)")
self.shape = P.Shape() self.shape = P.Shape()
self.reshape = P.Reshape() self.reshape = P.Reshape()
self.arg_shape = shape
block = BaseBlock block = BaseBlock
self.layer1 = MakeLayer(block, filter_list[0], filter_list[1], layers[0], stride=2) self.layer1 = MakeLayer(block, filter_list[0], filter_list[1], layers[0], stride=2)
@ -153,6 +147,7 @@ class SphereNet(Cell):
self.fc = fc_with_initialize(fc_size, feature_dim) self.fc = fc_with_initialize(fc_size, feature_dim)
self.last_bn = nn.BatchNorm1d(feature_dim, momentum=0.9).add_flags_recursive(fp32=True) self.last_bn = nn.BatchNorm1d(feature_dim, momentum=0.9).add_flags_recursive(fp32=True)
self.last_bn_sub = nn.BatchNorm2d(feature_dim, momentum=0.9).add_flags_recursive(fp32=True)
self.cast = P.Cast() self.cast = P.Cast()
self.l2norm = P.L2Normalize(axis=1) self.l2norm = P.L2Normalize(axis=1)
@ -164,6 +159,7 @@ class SphereNet(Cell):
cell.bias.set_data(initializer('zeros', cell.bias.shape)) cell.bias.set_data(initializer('zeros', cell.bias.shape))
else: else:
cell.weight.set_data(initializer(me_init.ReidXavierUniform(), cell.weight.shape)) cell.weight.set_data(initializer(me_init.ReidXavierUniform(), cell.weight.shape))
self.device_target = context.get_context('device_target')
def construct(self, x): def construct(self, x):
'''Construct function.''' '''Construct function.'''
@ -175,13 +171,99 @@ class SphereNet(Cell):
b, _, _, _ = self.shape(x) b, _, _, _ = self.shape(x)
x = self.reshape(x, (b, -1)) x = self.reshape(x, (b, -1))
x = self.fc(x) x = self.fc(x)
x = self.last_bn(x)
x = self.cast(x, mstype.float16) if self.device_target == 'Ascend':
x = self.last_bn(x)
else:
old_shape = x.shape
x = self.reshape(x, (old_shape[0], old_shape[1], 1, 1))
x = self.last_bn_sub(x)
x = self.reshape(x, old_shape)
if self.device_target != 'CPU':
x = self.cast(x, mstype.float16)
x = self.l2norm(x) x = self.l2norm(x)
return x return x
class SphereNet_float32(Cell):
'''SphereNet_float32'''
def __init__(self, num_layers=36, feature_dim=128, shape=(96, 64)):
super(SphereNet_float32, self).__init__()
assert num_layers in [12, 20, 36, 64], 'SphereNet num_layers should be 12, 20, 36 or 64'
if num_layers == 12:
layers = [1, 1, 1, 1]
filter_list = [3, 16, 32, 64, 128]
fc_size = 128 * 6 * 4
elif num_layers == 20:
layers = [1, 2, 4, 1]
filter_list = [3, 64, 128, 256, 512]
fc_size = 512 * 6 * 4
elif num_layers == 36:
layers = [2, 4, 4, 2]
filter_list = [3, 32, 64, 128, 256]
fc_size = 256 * 6 * 4
elif num_layers == 64:
layers = [3, 7, 16, 3]
filter_list = [3, 64, 128, 256, 512]
fc_size = 512 * 6 * 4
else:
raise ValueError('sphere' + str(num_layers) + " IS NOT SUPPORTED! (sphere12, sphere20, sphere36 or sphere64)")
self.shape = P.Shape()
self.reshape = P.Reshape()
block = BaseBlock
self.layer1 = MakeLayer(block, filter_list[0], filter_list[1], layers[0], stride=2)
self.layer2 = MakeLayer(block, filter_list[1], filter_list[2], layers[1], stride=2)
self.layer3 = MakeLayer(block, filter_list[2], filter_list[3], layers[2], stride=2)
self.layer4 = MakeLayer(block, filter_list[3], filter_list[4], layers[3], stride=2)
self.fc = fc_with_initialize(fc_size, feature_dim)
self.last_bn = nn.BatchNorm1d(feature_dim, momentum=0.9).add_flags_recursive(fp32=True)
self.last_bn_sub = nn.BatchNorm2d(feature_dim, momentum=0.9).add_flags_recursive(fp32=True)
self.cast = P.Cast()
self.l2norm = P.L2Normalize(axis=1)
for _, cell in self.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Dense)):
if cell.bias is not None:
cell.weight.set_data(initializer(me_init.ReidKaimingUniform(a=math.sqrt(5), mode='fan_out'),
cell.weight.shape))
cell.bias.set_data(initializer('zeros', cell.bias.shape))
else:
cell.weight.set_data(initializer(me_init.ReidXavierUniform(), cell.weight.shape))
self.device_target = context.get_context('device_target')
def construct(self, x):
'''Construct function.'''
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
b, _, _, _ = self.shape(x)
x = self.reshape(x, (b, -1))
x = self.fc(x)
if self.device_target == 'Ascend':
x = self.last_bn(x)
else:
old_shape = x.shape
x = self.reshape(x, (old_shape[0], old_shape[1], 1, 1))
x = self.last_bn_sub(x)
x = self.reshape(x, old_shape)
if self.device_target != 'CPU':
x = self.cast(x, mstype.float16)
x = self.l2norm(x)
x = self.cast(x, mstype.float32)
return x
class CombineMarginFC(nn.Cell): class CombineMarginFC(nn.Cell):
'''CombineMarginFC''' '''CombineMarginFC'''
def __init__(self, embbeding_size=128, classnum=270762, s=32, a=1.0, m=0.3, b=0.2): def __init__(self, embbeding_size=128, classnum=270762, s=32, a=1.0, m=0.3, b=0.2):
@ -208,12 +290,16 @@ class CombineMarginFC(nn.Cell):
self.cast = P.Cast() self.cast = P.Cast()
self.on_value = Tensor(1.0, mstype.float32) self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32) self.off_value = Tensor(0.0, mstype.float32)
self.device_target = context.get_context('device_target')
def construct(self, x, label): def construct(self, x, label):
'''Construct function.''' '''Construct function.'''
w = self.normalize(self.weight) w = self.normalize(self.weight)
cosine = self.fc(self.cast(x, mstype.float16), self.cast(w, mstype.float16)) if self.device_target == 'CPU':
cosine = self.cast(cosine, mstype.float32) cosine = self.fc(x, w)
else:
cosine = self.fc(self.cast(x, mstype.float16), self.cast(w, mstype.float16))
cosine = self.cast(cosine, mstype.float32)
cosine_shape = F.shape(cosine) cosine_shape = F.shape(cosine)
one_hot_float = self.onehot( one_hot_float = self.onehot(
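
Note on the non-Ascend branch of `construct` above: the `(N, C)` output of `fc` is reshaped to `(N, C, 1, 1)`, passed through `BatchNorm2d`, and reshaped back. With a 1x1 spatial extent, the per-channel statistics are the same as the per-feature statistics `BatchNorm1d` would compute. The sketch below checks this in plain Python. It is an illustration only: `batch_norm_1d`, `batch_norm_2d` and `bn1d_via_bn2d` are hypothetical helpers, not MindSpore APIs, and they ignore the learnable scale and shift and the running statistics.

```python
import math


def batch_norm_1d(x, eps=1e-5):
    """Normalize each feature column of an (N, C) batch using batch statistics."""
    n, c = len(x), len(x[0])
    out = [[0.0] * c for _ in range(n)]
    for j in range(c):
        col = [row[j] for row in x]
        mean = sum(col) / n
        var = sum((v - mean) ** 2 for v in col) / n
        for i in range(n):
            out[i][j] = (x[i][j] - mean) / math.sqrt(var + eps)
    return out


def batch_norm_2d(x, eps=1e-5):
    """Normalize each channel of an (N, C, H, W) batch over the N, H and W axes."""
    n, c = len(x), len(x[0])
    h, w = len(x[0][0]), len(x[0][0][0])
    out = [[[[0.0] * w for _ in range(h)] for _ in range(c)] for _ in range(n)]
    for j in range(c):
        vals = [x[i][j][a][b] for i in range(n) for a in range(h) for b in range(w)]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        for i in range(n):
            for a in range(h):
                for b in range(w):
                    out[i][j][a][b] = (x[i][j][a][b] - mean) / math.sqrt(var + eps)
    return out


def bn1d_via_bn2d(x):
    """Mimic the reshape trick: (N, C) -> (N, C, 1, 1) -> BatchNorm2d -> (N, C)."""
    expanded = [[[[v]] for v in row] for row in x]
    normed = batch_norm_2d(expanded)
    return [[channel[0][0] for channel in sample] for sample in normed]


batch = [[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]]
direct = batch_norm_1d(batch)
via_2d = bn1d_via_bn2d(batch)
```

Here `direct` and `via_2d` agree element by element, which is why the reshape path can stand in for `BatchNorm1d` on backends where the 1D operator is presumably unavailable or behaves differently.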


@ -1,310 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition backbone."""
import math
import mindspore.nn as nn
from mindspore.ops.operations import Add
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn import Dense, Cell
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore import Tensor, Parameter
from src import me_init
class Cut(nn.Cell):
def construct(self, x):
return x
def bn_with_initialize(out_channels):
bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5).add_flags_recursive(fp32=True)
return bn
def fc_with_initialize(input_channels, out_channels):
return Dense(input_channels, out_channels)
def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1, bias=True):
"""3x3 convolution with padding"""
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
pad_mode=pad_mode, group=groups, has_bias=bias, dilation=dilation, padding=padding)
def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0, bias=True):
"""1x1 convolution"""
return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride, has_bias=bias,
padding=padding)
def conv4x4(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1, bias=True):
"""4x4 convolution with padding"""
return nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride,
pad_mode=pad_mode, group=groups, has_bias=bias, dilation=dilation, padding=padding)
class BaseBlock(Cell):
'''BaseBlock'''
def __init__(self, channels):
super(BaseBlock, self).__init__()
self.conv1 = conv3x3(channels, channels, stride=1, padding=1, bias=False)
self.bn1 = bn_with_initialize(channels)
self.relu1 = P.ReLU()
self.conv2 = conv3x3(channels, channels, stride=1, padding=1, bias=False)
self.bn2 = bn_with_initialize(channels)
self.relu2 = P.ReLU()
self.cast = P.Cast()
self.add = Add()
def construct(self, x):
'''Construct function.'''
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu2(out)
# hand cast
identity = self.cast(identity, mstype.float16)
out = self.cast(out, mstype.float16)
out = self.add(out, identity)
return out
class MakeLayer(Cell):
'''MakeLayer'''
def __init__(self, block, inplanes, planes, blocks, stride=2):
super(MakeLayer, self).__init__()
self.conv = conv3x3(inplanes, planes, stride=stride, padding=1, bias=True)
self.bn = bn_with_initialize(planes)
self.relu = P.ReLU()
self.layers = []
for _ in range(0, blocks):
self.layers.append(block(planes))
self.layers = nn.CellList(self.layers)
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
for block in self.layers:
x = block(x)
return x
class SphereNet(Cell):
'''SphereNet'''
def __init__(self, num_layers=36, feature_dim=128, shape=(96, 64)):
super(SphereNet, self).__init__()
assert num_layers in [12, 20, 36, 64], 'SphereNet num_layers should be 12, 20 or 64'
if num_layers == 12:
layers = [1, 1, 1, 1]
filter_list = [3, 16, 32, 64, 128]
fc_size = 128 * 6 * 4
elif num_layers == 20:
layers = [1, 2, 4, 1]
filter_list = [3, 64, 128, 256, 512]
fc_size = 512 * 6 * 4
elif num_layers == 36:
layers = [2, 4, 4, 2]
filter_list = [3, 32, 64, 128, 256]
fc_size = 256 * 6 * 4
elif num_layers == 64:
layers = [3, 7, 16, 3]
filter_list = [3, 64, 128, 256, 512]
fc_size = 512 * 6 * 4
else:
raise ValueError('sphere' + str(num_layers) + " IS NOT SUPPORTED! (sphere20 or sphere64)")
self.shape = P.Shape()
self.reshape = P.Reshape()
self.arg_shape = shape
block = BaseBlock
self.layer1 = MakeLayer(block, filter_list[0], filter_list[1], layers[0], stride=2)
self.layer2 = MakeLayer(block, filter_list[1], filter_list[2], layers[1], stride=2)
self.layer3 = MakeLayer(block, filter_list[2], filter_list[3], layers[2], stride=2)
self.layer4 = MakeLayer(block, filter_list[3], filter_list[4], layers[3], stride=2)
self.fc = fc_with_initialize(fc_size, feature_dim)
self.last_bn = nn.BatchNorm1d(feature_dim, momentum=0.9).add_flags_recursive(fp32=True)
self.cast = P.Cast()
self.l2norm = P.L2Normalize(axis=1)
for _, cell in self.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Dense)):
if cell.bias is not None:
cell.weight.set_data(initializer(me_init.ReidKaimingUniform(a=math.sqrt(5), mode='fan_out'),
cell.weight.shape))
cell.bias.set_data(initializer('zeros', cell.bias.shape))
else:
cell.weight.set_data(initializer(me_init.ReidXavierUniform(), cell.weight.shape))
def construct(self, x):
'''Construct function.'''
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
b, _, _, _ = self.shape(x)
x = self.reshape(x, (b, -1))
x = self.fc(x)
x = self.last_bn(x)
x = self.cast(x, mstype.float16)
x = self.l2norm(x)
x = self.cast(x, mstype.float32)
return x
class CombineMarginFC(nn.Cell):
'''CombineMarginFC'''
def __init__(self, embbeding_size=128, classnum=270762, s=32, a=1.0, m=0.3, b=0.2):
super(CombineMarginFC, self).__init__()
weight_shape = [classnum, embbeding_size]
weight_init = initializer(me_init.ReidXavierUniform(), weight_shape)
self.weight = Parameter(weight_init, name='weight')
self.m = m
self.s = s
self.a = a
self.b = b
self.m_const = Tensor(self.m, dtype=mstype.float32)
self.a_const = Tensor(self.a, dtype=mstype.float32)
self.b_const = Tensor(self.b, dtype=mstype.float32)
self.s_const = Tensor(self.s, dtype=mstype.float32)
self.m_const_zero = Tensor(0.0, dtype=mstype.float32)
self.a_const_one = Tensor(1.0, dtype=mstype.float32)
self.normalize = P.L2Normalize(axis=1)
self.fc = P.MatMul(transpose_b=True)
self.onehot = P.OneHot()
self.transpose = P.Transpose()
self.acos = P.ACos()
self.cos = P.Cos()
self.cast = P.Cast()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
def construct(self, x, label):
'''Construct function.'''
w = self.normalize(self.weight)
cosine = self.fc(self.cast(x, mstype.float16), self.cast(w, mstype.float16))
cosine = self.cast(cosine, mstype.float32)
cosine_shape = F.shape(cosine)
one_hot_float = self.onehot(
self.cast(label, mstype.int32), cosine_shape[1], self.on_value, self.off_value)
theta = self.acos(cosine)
theta = self.a_const * theta
theta = self.m_const + theta
body = self.cos(theta)
body = body - self.b_const
cos_mask = F.scalar_to_array(1.0) - one_hot_float
output = body * one_hot_float + cosine * cos_mask
output = output * self.s_const
return output, cosine
class CombineMarginFCFp16(nn.Cell):
'''CombineMarginFCFp16'''
def __init__(self, embbeding_size=128, classnum=270762, s=32, a=1.0, m=0.3, b=0.2):
super(CombineMarginFCFp16, self).__init__()
weight_shape = [classnum, embbeding_size]
weight_init = initializer(me_init.ReidXavierUniform(), weight_shape)
self.weight = Parameter(weight_init, name='weight')
self.m = m
self.s = s
self.a = a
self.b = b
self.m_const = Tensor(self.m, dtype=mstype.float16)
self.a_const = Tensor(self.a, dtype=mstype.float16)
self.b_const = Tensor(self.b, dtype=mstype.float16)
self.s_const = Tensor(self.s, dtype=mstype.float16)
self.m_const_zero = Tensor(0, dtype=mstype.float16)
self.a_const_one = Tensor(1, dtype=mstype.float16)
self.normalize = P.L2Normalize(axis=1)
self.fc = P.MatMul(transpose_b=True)
self.onehot = P.OneHot()
self.transpose = P.Transpose()
self.acos = P.ACos()
self.cos = P.Cos()
self.cast = P.Cast()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
def construct(self, x, label):
'''Construct function.'''
w = self.normalize(self.weight)
cosine = self.fc(x, w)
cosine_shape = F.shape(cosine)
one_hot_float = self.onehot(
self.cast(label, mstype.int32), cosine_shape[1], self.on_value, self.off_value)
one_hot_float = self.cast(one_hot_float, mstype.float16)
theta = self.acos(cosine)
theta = self.a_const * theta
theta = self.m_const + theta
body = self.cos(theta)
body = body - self.b_const
cos_mask = self.cast(F.scalar_to_array(1.0), mstype.float16) - one_hot_float
output = body * one_hot_float + cosine * cos_mask
output = output * self.s_const
return output, cosine
class BuildTrainNetwork(Cell):
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
def construct(self, input_data, label):
output = self.network(input_data)
loss = self.criterion(output, label)
return loss
class BuildTrainNetworkWithHead(nn.Cell):
'''Build TrainNetwork With Head.'''
def __init__(self, model, head, criterion):
super(BuildTrainNetworkWithHead, self).__init__()
self.model = model
self.head = head
self.criterion = criterion
def construct(self, input_data, labels):
embeddings = self.model(input_data)
thetas, _ = self.head(embeddings, labels)
loss = self.criterion(thetas, labels)
return loss
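
Note on `CombineMarginFC.construct` in the file above: for the target class the head replaces `cosine` with `cos(a * acos(cosine) + m) - b`, leaves the other classes unchanged, and scales everything by `s`. A minimal plain-Python sketch of that arithmetic for a single sample follows. The function name `combine_margin_logits` is hypothetical, and the defaults are copied from the constructor signature shown in the diff.

```python
import math


def combine_margin_logits(cosines, label, s=32.0, a=1.0, m=0.3, b=0.2):
    """Apply the combined margin to one sample's cosine similarities.

    cosines: list of cos(theta) values, one per class, each in [-1, 1].
    label:   index of the ground-truth class.
    """
    out = []
    for idx, cosine in enumerate(cosines):
        if idx == label:
            # Target class: widen the angle by the margins, then subtract b.
            cosine = math.cos(a * math.acos(cosine) + m) - b
        out.append(s * cosine)
    return out


logits = combine_margin_logits([0.5, 0.1], label=0)
```

The target-class logit ends up smaller than the plain `s * cosine`, so the network has to produce a larger cosine for the correct class to reach the same loss.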


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -32,31 +32,45 @@ from mindspore.nn import TrainOneStepCell
from mindspore.communication.management import get_group_size, init, get_rank from mindspore.communication.management import get_group_size, init, get_rank
from src.dataset import get_de_dataset from src.dataset import get_de_dataset
from src.config import reid_1p_cfg, reid_8p_cfg from src.config import reid_1p_cfg_ascend, reid_1p_cfg, reid_8p_cfg_ascend, reid_8p_cfg_gpu
from src.lr_generator import step_lr from src.lr_generator import step_lr
from src.log import get_logger, AverageMeter from src.log import get_logger, AverageMeter
from src.reid import SphereNet, CombineMarginFCFp16, BuildTrainNetworkWithHead from src.reid import SphereNet, CombineMarginFCFp16, BuildTrainNetworkWithHead, CombineMarginFC
from src.loss import CrossEntropy from src.loss import CrossEntropy
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
def init_argument(): def init_argument():
"""init config argument.""" """init config argument."""
parser = argparse.ArgumentParser(description='Cifar10 classification') parser = argparse.ArgumentParser(description='Face Recognition For Tracking')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device_target')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--data_dir', type=str, default='', help='image label list file, e.g. /home/label.txt') parser.add_argument('--data_dir', type=str, default='', help='image folders')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
args = parser.parse_args() args = parser.parse_args()
graph_path = os.path.join('./graphs_graphmode', datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True,
save_graphs_path=graph_path)
if args.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
if args.is_distributed == 0: if args.is_distributed == 0:
cfg = reid_1p_cfg if args.device_target == 'Ascend':
cfg = reid_1p_cfg_ascend
else:
cfg = reid_1p_cfg
else: else:
cfg = reid_8p_cfg if args.device_target == 'Ascend':
cfg = reid_8p_cfg_ascend
else:
cfg = reid_8p_cfg_gpu
cfg.pretrained = args.pretrained cfg.pretrained = args.pretrained
cfg.data_dir = args.data_dir cfg.data_dir = args.data_dir
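
Note on the config selection in this hunk: the choice depends on both `--is_distributed` and `--device_target`, and any non-Ascend target falls through to the generic branch. The sketch below restates that branching as a small pure function that returns the config name as a string. `select_cfg_name` is a hypothetical helper for illustration, not part of the commit.

```python
def select_cfg_name(device_target, is_distributed):
    """Return the name of the config object the new branching would pick."""
    if is_distributed == 0:
        # Single-device run: Ascend has its own config, GPU and CPU share one.
        return 'reid_1p_cfg_ascend' if device_target == 'Ascend' else 'reid_1p_cfg'
    # Multi-device run: every non-Ascend target gets the GPU config.
    return 'reid_8p_cfg_ascend' if device_target == 'Ascend' else 'reid_8p_cfg_gpu'


choices = {
    (target, dist): select_cfg_name(target, dist)
    for target in ('Ascend', 'GPU', 'CPU')
    for dist in (0, 1)
}
```

One consequence worth noting: as written in the diff, `--device_target CPU` with `--is_distributed 1` selects `reid_8p_cfg_gpu`, since there is no CPU-specific multi-device config.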
@ -81,10 +95,10 @@ def init_argument():
# Show cfg # Show cfg
cfg.logger.save_args(cfg) cfg.logger.save_args(cfg)
return cfg return cfg, args
def main(): def main():
cfg = init_argument() cfg, args = init_argument()
loss_meter = AverageMeter('loss') loss_meter = AverageMeter('loss')
# dataloader # dataloader
cfg.logger.info('start create dataloader') cfg.logger.info('start create dataloader')
@ -104,7 +118,10 @@ def main():
create_network_start = time.time() create_network_start = time.time()
network = SphereNet(num_layers=cfg.net_depth, feature_dim=cfg.embedding_size, shape=cfg.input_size) network = SphereNet(num_layers=cfg.net_depth, feature_dim=cfg.embedding_size, shape=cfg.input_size)
head = CombineMarginFCFp16(embbeding_size=cfg.embedding_size, classnum=cfg.class_num) if args.device_target == 'CPU':
head = CombineMarginFC(embbeding_size=cfg.embedding_size, classnum=cfg.class_num)
else:
head = CombineMarginFCFp16(embbeding_size=cfg.embedding_size, classnum=cfg.class_num)
criterion = CrossEntropy() criterion = CrossEntropy()
# load the pretrained model # load the pretrained model
@ -122,8 +139,12 @@ def main():
cfg.logger.info('load model %s success' % cfg.pretrained) cfg.logger.info('load model %s success' % cfg.pretrained)
# mixed precision training # mixed precision training
network.add_flags_recursive(fp16=True) if args.device_target == 'CPU':
head.add_flags_recursive(fp16=True) network.add_flags_recursive(fp32=True)
head.add_flags_recursive(fp32=True)
else:
network.add_flags_recursive(fp16=True)
head.add_flags_recursive(fp16=True)
criterion.add_flags_recursive(fp32=True) criterion.add_flags_recursive(fp32=True)
train_net = BuildTrainNetworkWithHead(network, head, criterion) train_net = BuildTrainNetworkWithHead(network, head, criterion)