forked from mindspore-Ecosystem/mindspore
!9164 Add FaceAttribute network to model_zoo/research/cv/
From: @zhanghuiyao Reviewed-by: @oacjiewen Signed-off-by:
This commit is contained in:
commit
84fc72f67d
|
@ -0,0 +1,271 @@
|
|||
# Contents
|
||||
|
||||
- [Face Attribute Description](#face-attribute-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Running Example](#running-example)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [Face Attribute Description](#contents)
|
||||
|
||||
This is a Face Attributes Recognition network based on Resnet18, with support for training and evaluation on Ascend910.
|
||||
|
||||
ResNet (residual neural network) was proposed by Kaiming He and other four Chinese of Microsoft Research Institute. Through the use of ResNet unit, it successfully trained 152 layers of neural network, and won the championship in ilsvrc2015. The error rate on top 5 was 3.57%, and the parameter quantity was lower than vggnet, so the effect was very outstanding. Traditional convolution network or full connection network will have more or less information loss. At the same time, it will lead to the disappearance or explosion of gradient, which leads to the failure of deep network training. ResNet solves this problem to a certain extent. By passing the input information to the output, the integrity of the information is protected. The whole network only needs to learn the part of the difference between input and output, which simplifies the learning objectives and difficulties.The structure of ResNet can accelerate the training of neural network very quickly, and the accuracy of the model is also greatly improved. At the same time, ResNet is very popular, even can be directly used in the concept net network.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1512.03385.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
Face Attribute uses a modified-Resnet18 network for performing feature extraction.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
This network can recognize the age/gender/mask from a human face. The default rule is:
|
||||
|
||||
```python
|
||||
age:
|
||||
0: 0~2 years
|
||||
1: 3~9 years
|
||||
2: 10~19 years
|
||||
3: 20~29 years
|
||||
4: 30~39 years
|
||||
5: 40~49 years
|
||||
6: 50~59 years
|
||||
7: 60~69 years
|
||||
8: 70+ years
|
||||
|
||||
gender:
|
||||
0: male
|
||||
1: female
|
||||
|
||||
mask:
|
||||
0: wearing mask
|
||||
1: without mask
|
||||
```
|
||||
|
||||
We use about 91K face images as training dataset and 11K as evaluating dataset in this example, and you can also use your own datasets or open source datasets (e.g. FairFace and RWMFD)
|
||||
|
||||
- step 1: The dataset should be saved in a txt file, which contain the following contents:
|
||||
|
||||
```python
|
||||
[PATH_TO_IMAGE]/1.jpg [LABEL_AGE] [LABEL_GENDER] [LABEL_MASK]
|
||||
[PATH_TO_IMAGE]/2.jpg [LABEL_AGE] [LABEL_GENDER] [LABEL_MASK]
|
||||
[PATH_TO_IMAGE]/3.jpg [LABEL_AGE] [LABEL_GENDER] [LABEL_MASK]
|
||||
...
|
||||
```
|
||||
|
||||
The value range of [LABEL_AGE] is [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], -1 means the label should be ignored.
|
||||
|
||||
The value range of [LABEL_GENDER] is [-1, 0, 1], -1 means the label should be ignored.
|
||||
|
||||
The value range of [LABEL_MASK] is [-1, 0, 1], -1 means the label should be ignored.
|
||||
|
||||
- step 2: Convert the dataset to mindrecord:
|
||||
|
||||
```bash
|
||||
python src/data_to_mindrecord_train.py
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```bash
|
||||
python src/data_to_mindrecord_eval.py
|
||||
```
|
||||
|
||||
If your dataset is too big to convert at a time, you can add data to an existed mindrecord in turn:
|
||||
|
||||
```bash
|
||||
python src/data_to_mindrecord_train_append.py
|
||||
```
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
The entire code structure is as following:
|
||||
|
||||
```python
|
||||
.
|
||||
└─ Face Attribute
|
||||
├─ README.md
|
||||
├─ scripts
|
||||
├─ run_standalone_train.sh # launch standalone training(1p) in ascend
|
||||
├─ run_distribute_train.sh # launch distributed training(8p) in ascend
|
||||
├─ run_eval.sh # launch evaluating in ascend
|
||||
└─ run_export.sh # launch exporting air model
|
||||
├─ src
|
||||
├─ FaceAttribute
|
||||
├─ cross_entropy.py # cross entroy loss
|
||||
├─ custom_net.py # network unit
|
||||
├─ loss_factory.py # loss function
|
||||
├─ head_factory.py # network head
|
||||
├─ resnet18.py # network backbone
|
||||
├─ head_factory_softmax.py # network head with softmax
|
||||
└─ resnet18_softmax.py # network backbone with softmax
|
||||
├─ config.py # parameter configuration
|
||||
├─ dataset_eval.py # dataset loading and preprocessing for evaluating
|
||||
├─ dataset_train.py # dataset loading and preprocessing for training
|
||||
├─ logging.py # log function
|
||||
├─ lrsche_factory.py # generate learning rate
|
||||
├─ data_to_mindrecord_train.py # convert dataset to mindrecord for training
|
||||
├─ data_to_mindrecord_train_append.py # add dataset to an existed mindrecord for training
|
||||
└─ data_to_mindrecord_eval.py # convert dataset to mindrecord for evaluating
|
||||
├─ train.py # training scripts
|
||||
├─ eval.py # evaluation scripts
|
||||
└─ export.py # export air model
|
||||
```
|
||||
|
||||
## [Running Example](#contents)
|
||||
|
||||
### Train
|
||||
|
||||
- Stand alone mode
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]
|
||||
```
|
||||
|
||||
or (fine-tune)
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
|
||||
```
|
||||
|
||||
for example:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh /home/train.mindrecord 0 /home/a.ckpt
|
||||
```
|
||||
|
||||
- Distribute mode (recommended)
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]
|
||||
```
|
||||
|
||||
or (fine-tune)
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]
|
||||
```
|
||||
|
||||
for example:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_distribute_train.sh /home/train.mindrecord ./rank_table_8p.json /home/a.ckpt
|
||||
```
|
||||
|
||||
You will get the loss value of each step as following in "./output/[TIME]/[TIME].log" or "./scripts/device0/train.log":
|
||||
|
||||
```python
|
||||
epoch[0], iter[0], loss:4.489518, 12.92 imgs/sec
|
||||
epoch[0], iter[10], loss:3.619693, 13792.76 imgs/sec
|
||||
epoch[0], iter[20], loss:3.580932, 13817.78 imgs/sec
|
||||
epoch[0], iter[30], loss:3.574254, 7834.65 imgs/sec
|
||||
epoch[0], iter[40], loss:3.557742, 7884.87 imgs/sec
|
||||
|
||||
...
|
||||
epoch[69], iter[6120], loss:1.225308, 9561.00 imgs/sec
|
||||
epoch[69], iter[6130], loss:1.209557, 8913.28 imgs/sec
|
||||
epoch[69], iter[6140], loss:1.158641, 9755.81 imgs/sec
|
||||
epoch[69], iter[6150], loss:1.167064, 9300.77 imgs/sec
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
|
||||
```
|
||||
|
||||
for example:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_eval.sh /home/eval.mindrecord 0 /home/a.ckpt
|
||||
```
|
||||
|
||||
You will get the result as following in "./scripts/device0/eval.log" or txt file in [PRETRAINED_BACKBONE]'s folder:
|
||||
|
||||
```python
|
||||
age accuracy: 0.45773233522001094
|
||||
gen accuracy: 0.8950155194449516
|
||||
mask accuracy: 0.992539346357495
|
||||
gen precision: 0.8869598765432098
|
||||
gen recall: 0.8907400232468036
|
||||
gen f1: 0.88884593079451
|
||||
mask precision: 1.0
|
||||
mask recall: 0.998539346357495
|
||||
mask f1: 0.9992691394116572
|
||||
```
|
||||
|
||||
### Convert model
|
||||
|
||||
If you want to infer the network on Ascend 310, you should convert the model to AIR:
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters | Face Attribute |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
|
||||
| uploaded Date | 09/30/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | 91K images |
|
||||
| Training Parameters | epoch=70, batch_size=128, momentum=0.9, lr=0.001 |
|
||||
| Optimizer | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Speed | 1pc: 200~250 ms/step; 8pcs: 100~150 ms/step |
|
||||
| Total time | 1pc: 2.5 hours; 8pcs: 0.3 hours |
|
||||
| Checkpoint for Fine tuning | 88M (.ckpt file) |
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | Face Attribute |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/30/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | 11K images |
|
||||
| batch_size | 1 |
|
||||
| outputs | accuracy |
|
||||
| Accuracy(8pcs) | age:45.7% |
|
||||
| | gender:89.5% |
|
||||
| | mask:99.2% |
|
||||
| Model for inference | 88M (.ckpt file) |
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,189 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute eval."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.dataset_eval import data_generator_eval
|
||||
from src.config import config
|
||||
from src.FaceAttribute.resnet18 import get_resnet18
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||
|
||||
|
||||
def softmax(x, axis=0):
|
||||
return np.exp(x) / np.sum(np.exp(x), axis=axis)
|
||||
|
||||
|
||||
def main(args):
|
||||
network = get_resnet18(args)
|
||||
ckpt_path = args.model_path
|
||||
if os.path.isfile(ckpt_path):
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('network.'):
|
||||
param_dict_new[key[8:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
print('-----------------------load model success-----------------------')
|
||||
else:
|
||||
print('-----------------------load model failed-----------------------')
|
||||
|
||||
network.set_train(False)
|
||||
|
||||
de_dataloader, steps_per_epoch, _ = data_generator_eval(args)
|
||||
|
||||
total_data_num_age = 0
|
||||
total_data_num_gen = 0
|
||||
total_data_num_mask = 0
|
||||
age_num = 0
|
||||
gen_num = 0
|
||||
mask_num = 0
|
||||
gen_tp_num = 0
|
||||
mask_tp_num = 0
|
||||
gen_fp_num = 0
|
||||
mask_fp_num = 0
|
||||
gen_fn_num = 0
|
||||
mask_fn_num = 0
|
||||
for step_i, (data, gt_classes) in enumerate(de_dataloader):
|
||||
|
||||
print('evaluating {}/{} ...'.format(step_i + 1, steps_per_epoch))
|
||||
|
||||
data_tensor = Tensor(data, dtype=mstype.float32)
|
||||
fea = network(data_tensor)
|
||||
|
||||
gt_age, gt_gen, gt_mask = gt_classes[0]
|
||||
|
||||
age_result, gen_result, mask_result = fea
|
||||
|
||||
age_result_np = age_result.asnumpy()
|
||||
gen_result_np = gen_result.asnumpy()
|
||||
mask_result_np = mask_result.asnumpy()
|
||||
|
||||
age_prob = softmax(age_result_np[0].astype(np.float32)).tolist()
|
||||
gen_prob = softmax(gen_result_np[0].astype(np.float32)).tolist()
|
||||
mask_prob = softmax(mask_result_np[0].astype(np.float32)).tolist()
|
||||
|
||||
age = age_prob.index(max(age_prob))
|
||||
gen = gen_prob.index(max(gen_prob))
|
||||
mask = mask_prob.index(max(mask_prob))
|
||||
|
||||
if gt_age == age:
|
||||
age_num += 1
|
||||
if gt_gen == gen:
|
||||
gen_num += 1
|
||||
if gt_mask == mask:
|
||||
mask_num += 1
|
||||
|
||||
if gt_gen == 1 and gen == 1:
|
||||
gen_tp_num += 1
|
||||
if gt_gen == 0 and gen == 1:
|
||||
gen_fp_num += 1
|
||||
if gt_gen == 1 and gen == 0:
|
||||
gen_fn_num += 1
|
||||
|
||||
if gt_mask == 1 and mask == 1:
|
||||
mask_tp_num += 1
|
||||
if gt_mask == 0 and mask == 1:
|
||||
mask_fp_num += 1
|
||||
if gt_mask == 1 and mask == 0:
|
||||
mask_fn_num += 1
|
||||
|
||||
if gt_age != -1:
|
||||
total_data_num_age += 1
|
||||
if gt_gen != -1:
|
||||
total_data_num_gen += 1
|
||||
if gt_mask != -1:
|
||||
total_data_num_mask += 1
|
||||
|
||||
age_accuracy = float(age_num) / float(total_data_num_age)
|
||||
|
||||
gen_precision = float(gen_tp_num) / (float(gen_tp_num) + float(gen_fp_num))
|
||||
gen_recall = float(gen_tp_num) / (float(gen_tp_num) + float(gen_fn_num))
|
||||
gen_accuracy = float(gen_num) / float(total_data_num_gen)
|
||||
gen_f1 = 2. * gen_precision * gen_recall / (gen_precision + gen_recall)
|
||||
|
||||
mask_precision = float(mask_tp_num) / (float(mask_tp_num) + float(mask_fp_num))
|
||||
mask_recall = float(mask_tp_num) / (float(mask_tp_num) + float(mask_fn_num))
|
||||
mask_accuracy = float(mask_num) / float(total_data_num_mask)
|
||||
mask_f1 = 2. * mask_precision * mask_recall / (mask_precision + mask_recall)
|
||||
|
||||
print('model: ', ckpt_path)
|
||||
print('total age num: ', total_data_num_age)
|
||||
print('total gen num: ', total_data_num_gen)
|
||||
print('total mask num: ', total_data_num_mask)
|
||||
print('age accuracy: ', age_accuracy)
|
||||
print('gen accuracy: ', gen_accuracy)
|
||||
print('mask accuracy: ', mask_accuracy)
|
||||
print('gen precision: ', gen_precision)
|
||||
print('gen recall: ', gen_recall)
|
||||
print('gen f1: ', gen_f1)
|
||||
print('mask precision: ', mask_precision)
|
||||
print('mask recall: ', mask_recall)
|
||||
print('mask f1: ', mask_f1)
|
||||
|
||||
model_name = os.path.basename(ckpt_path).split('.')[0]
|
||||
model_dir = os.path.dirname(ckpt_path)
|
||||
result_txt = os.path.join(model_dir, model_name + '.txt')
|
||||
if os.path.exists(result_txt):
|
||||
os.remove(result_txt)
|
||||
with open(result_txt, 'a') as ft:
|
||||
ft.write('model: {}\n'.format(ckpt_path))
|
||||
ft.write('total age num: {}\n'.format(total_data_num_age))
|
||||
ft.write('total gen num: {}\n'.format(total_data_num_gen))
|
||||
ft.write('total mask num: {}\n'.format(total_data_num_mask))
|
||||
ft.write('age accuracy: {}\n'.format(age_accuracy))
|
||||
ft.write('gen accuracy: {}\n'.format(gen_accuracy))
|
||||
ft.write('mask accuracy: {}\n'.format(mask_accuracy))
|
||||
ft.write('gen precision: {}\n'.format(gen_precision))
|
||||
ft.write('gen recall: {}\n'.format(gen_recall))
|
||||
ft.write('gen f1: {}\n'.format(gen_f1))
|
||||
ft.write('mask precision: {}\n'.format(mask_precision))
|
||||
ft.write('mask recall: {}\n'.format(mask_recall))
|
||||
ft.write('mask f1: {}\n'.format(mask_f1))
|
||||
|
||||
def parse_args():
|
||||
"""parse_args"""
|
||||
parser = argparse.ArgumentParser(description='face attributes eval')
|
||||
parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
|
||||
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args_1 = parse_args()
|
||||
|
||||
args_1.dst_h = config.dst_h
|
||||
args_1.dst_w = config.dst_w
|
||||
args_1.attri_num = config.attri_num
|
||||
args_1.classes = config.classes
|
||||
args_1.flat_dim = config.flat_dim
|
||||
args_1.fc_dim = config.fc_dim
|
||||
args_1.workers = config.workers
|
||||
|
||||
main(args_1)
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert ckpt to air."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
|
||||
|
||||
from src.FaceAttribute.resnet18_softmax import get_resnet18
|
||||
from src.config import config
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||
|
||||
|
||||
def main(args):
|
||||
network = get_resnet18(args)
|
||||
ckpt_path = args.model_path
|
||||
if os.path.isfile(ckpt_path):
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('network.'):
|
||||
param_dict_new[key[8:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
print('-----------------------load model success-----------------------')
|
||||
else:
|
||||
print('-----------------------load model failed -----------------------')
|
||||
|
||||
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 112, 112)).astype(np.float32)
|
||||
tensor_input_data = Tensor(input_data)
|
||||
|
||||
export(network, tensor_input_data, file_name=ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'),
|
||||
file_format='AIR')
|
||||
print('-----------------------export model success-----------------------')
|
||||
|
||||
def parse_args():
|
||||
"""parse_args"""
|
||||
parser = argparse.ArgumentParser(description='Convert ckpt to air')
|
||||
parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
||||
|
||||
if __name__ == "__main__":
|
||||
args_1 = parse_args()
|
||||
|
||||
args_1.dst_h = config.dst_h
|
||||
args_1.dst_w = config.dst_w
|
||||
args_1.attri_num = config.attri_num
|
||||
args_1.classes = config.classes
|
||||
args_1.flat_dim = config.flat_dim
|
||||
args_1.fc_dim = config.fc_dim
|
||||
|
||||
main(args_1)
|
|
@ -0,0 +1,81 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]"
|
||||
echo " or: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname "$(pwd)")
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
SCRIPT_NAME='train.py'
|
||||
|
||||
rm -rf ${current_exec_path}/device*
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
MINDRECORD_FILE=$(get_real_path $1)
|
||||
RANK_TABLE=$(get_real_path $2)
|
||||
PRETRAINED_BACKBONE=''
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo $MINDRECORD_FILE
|
||||
echo $RANK_TABLE
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
export RANK_TABLE_FILE=$RANK_TABLE
|
||||
export RANK_SIZE=8
|
||||
|
||||
echo 'start training'
|
||||
for((i=0;i<=$RANK_SIZE-1;i++));
|
||||
do
|
||||
echo 'start rank '$i
|
||||
mkdir ${current_exec_path}/device$i
|
||||
cd ${current_exec_path}/device$i || exit
|
||||
export RANK_ID=$i
|
||||
dev=`expr $i + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--mindrecord_path=$MINDRECORD_FILE \
|
||||
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
|
||||
done
|
||||
|
||||
echo 'running'
|
|
@ -0,0 +1,71 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname "$(pwd)")
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='eval.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
MINDRECORD_FILE=$(get_real_path $1)
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo $MINDRECORD_FILE
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start evaluating'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--mindrecord_path=$MINDRECORD_FILE \
|
||||
--model_path=$PRETRAINED_BACKBONE > eval.log 2>&1 &
|
||||
|
||||
echo 'running'
|
|
@ -0,0 +1,71 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname "$(pwd)")
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='export.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
BATCH_SIZE=$1
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo $BATCH_SIZE
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start converting'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--model_path=$PRETRAINED_BACKBONE > convert.log 2>&1 &
|
||||
|
||||
echo 'running'
|
|
@ -0,0 +1,77 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
echo " or: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname "$(pwd)")
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='train.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
MINDRECORD_FILE=$(get_real_path $1)
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=''
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo $MINDRECORD_FILE
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start training'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--world_size=1 \
|
||||
--mindrecord_path=$MINDRECORD_FILE \
|
||||
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
|
||||
|
||||
echo 'running'
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute cross entropy."""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class CrossEntropyWithIgnoreIndex(nn.Cell):
|
||||
'''Cross Entropy With Ignore Index Loss.'''
|
||||
def __init__(self):
|
||||
super(CrossEntropyWithIgnoreIndex, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, dtype=mstype.float32)
|
||||
self.off_value = Tensor(0.0, dtype=mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.greater = P.Greater()
|
||||
self.maximum = P.Maximum()
|
||||
self.fill = P.Fill()
|
||||
self.sum = P.ReduceSum(keep_dims=False)
|
||||
self.dtype = P.DType()
|
||||
self.relu = P.ReLU()
|
||||
self.reshape = P.Reshape()
|
||||
self.const_one = Tensor(np.ones([1]), dtype=mstype.float32)
|
||||
self.const_eps = Tensor(0.00001, dtype=mstype.float32)
|
||||
|
||||
def construct(self, x, label):
|
||||
'''Construct function.'''
|
||||
mask = self.reshape(label, (F.shape(label)[0], 1))
|
||||
mask = self.cast(mask, mstype.float32)
|
||||
mask = mask + self.const_eps
|
||||
mask = self.relu(mask)/mask
|
||||
x = x * mask
|
||||
one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(x)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(x, one_hot_label)
|
||||
positive = self.sum(self.cast(self.greater(loss, self.fill(self.dtype(loss), F.shape(loss), 0.0)),
|
||||
mstype.float32), 0)
|
||||
positive = self.maximum(positive, self.const_one)
|
||||
loss = self.sum(loss, 0) / positive
|
||||
return loss
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute network unit."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn import Dense
|
||||
|
||||
|
||||
class Cut(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def bn_with_initialize(out_channels):
|
||||
bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5)
|
||||
return bn
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
return Dense(input_channels, out_channels)
|
||||
|
||||
|
||||
def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
|
||||
pad_mode=pad_mode, group=groups, has_bias=False, dilation=dilation, padding=padding)
|
||||
|
||||
|
||||
def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride, has_bias=False,
|
||||
padding=padding)
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute head."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from src.FaceAttribute.custom_net import fc_with_initialize
|
||||
|
||||
__all__ = ['get_attri_head']
|
||||
|
||||
|
||||
class AttriHead(Cell):
|
||||
'''Attribute Head.'''
|
||||
def __init__(self, flat_dim, fc_dim, attri_num_list):
|
||||
super(AttriHead, self).__init__()
|
||||
self.fc1 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc1_relu = P.ReLU()
|
||||
self.fc1_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc1 = fc_with_initialize(fc_dim, attri_num_list[0])
|
||||
self.attri_fc1_relu = P.ReLU()
|
||||
self.attri_bn1 = nn.BatchNorm1d(attri_num_list[0], affine=False)
|
||||
|
||||
self.fc2 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc2_relu = P.ReLU()
|
||||
self.fc2_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc2 = fc_with_initialize(fc_dim, attri_num_list[1])
|
||||
self.attri_fc2_relu = P.ReLU()
|
||||
self.attri_bn2 = nn.BatchNorm1d(attri_num_list[1], affine=False)
|
||||
|
||||
self.fc3 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc3_relu = P.ReLU()
|
||||
self.fc3_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc3 = fc_with_initialize(fc_dim, attri_num_list[2])
|
||||
self.attri_fc3_relu = P.ReLU()
|
||||
self.attri_bn3 = nn.BatchNorm1d(attri_num_list[2], affine=False)
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
output0 = self.fc1(x)
|
||||
output0 = self.fc1_relu(output0)
|
||||
output0 = self.fc1_bn(output0)
|
||||
output0 = self.attri_fc1(output0)
|
||||
output0 = self.attri_fc1_relu(output0)
|
||||
output0 = self.attri_bn1(output0)
|
||||
|
||||
output1 = self.fc2(x)
|
||||
output1 = self.fc2_relu(output1)
|
||||
output1 = self.fc2_bn(output1)
|
||||
output1 = self.attri_fc2(output1)
|
||||
output1 = self.attri_fc2_relu(output1)
|
||||
output1 = self.attri_bn2(output1)
|
||||
|
||||
output2 = self.fc3(x)
|
||||
output2 = self.fc3_relu(output2)
|
||||
output2 = self.fc3_bn(output2)
|
||||
output2 = self.attri_fc3(output2)
|
||||
output2 = self.attri_fc3_relu(output2)
|
||||
output2 = self.attri_bn3(output2)
|
||||
|
||||
return output0, output1, output2
|
||||
|
||||
|
||||
def get_attri_head(flat_dim, fc_dim, attri_num_list):
|
||||
attri_head = AttriHead(flat_dim, fc_dim, attri_num_list)
|
||||
return attri_head
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute head."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from src.FaceAttribute.custom_net import fc_with_initialize
|
||||
|
||||
__all__ = ['get_attri_head']
|
||||
|
||||
|
||||
class AttriHead(Cell):
|
||||
'''Attribute Head.'''
|
||||
def __init__(self, flat_dim, fc_dim, attri_num_list):
|
||||
super(AttriHead, self).__init__()
|
||||
self.fc1 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc1_relu = P.ReLU()
|
||||
self.fc1_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc1 = fc_with_initialize(fc_dim, attri_num_list[0])
|
||||
self.attri_fc1_relu = P.ReLU()
|
||||
self.attri_bn1 = nn.BatchNorm1d(attri_num_list[0], affine=False)
|
||||
self.softmax1 = P.Softmax()
|
||||
|
||||
self.fc2 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc2_relu = P.ReLU()
|
||||
self.fc2_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc2 = fc_with_initialize(fc_dim, attri_num_list[1])
|
||||
self.attri_fc2_relu = P.ReLU()
|
||||
self.attri_bn2 = nn.BatchNorm1d(attri_num_list[1], affine=False)
|
||||
self.softmax2 = P.Softmax()
|
||||
|
||||
self.fc3 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc3_relu = P.ReLU()
|
||||
self.fc3_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc3 = fc_with_initialize(fc_dim, attri_num_list[2])
|
||||
self.attri_fc3_relu = P.ReLU()
|
||||
self.attri_bn3 = nn.BatchNorm1d(attri_num_list[2], affine=False)
|
||||
self.softmax3 = P.Softmax()
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
output0 = self.fc1(x)
|
||||
output0 = self.fc1_relu(output0)
|
||||
output0 = self.fc1_bn(output0)
|
||||
output0 = self.attri_fc1(output0)
|
||||
output0 = self.attri_fc1_relu(output0)
|
||||
output0 = self.attri_bn1(output0)
|
||||
output0 = self.softmax1(output0)
|
||||
|
||||
output1 = self.fc2(x)
|
||||
output1 = self.fc2_relu(output1)
|
||||
output1 = self.fc2_bn(output1)
|
||||
output1 = self.attri_fc2(output1)
|
||||
output1 = self.attri_fc2_relu(output1)
|
||||
output1 = self.attri_bn2(output1)
|
||||
output1 = self.softmax2(output1)
|
||||
|
||||
output2 = self.fc3(x)
|
||||
output2 = self.fc3_relu(output2)
|
||||
output2 = self.fc3_bn(output2)
|
||||
output2 = self.attri_fc3(output2)
|
||||
output2 = self.attri_fc3_relu(output2)
|
||||
output2 = self.attri_bn3(output2)
|
||||
output2 = self.softmax3(output2)
|
||||
|
||||
return output0, output1, output2
|
||||
|
||||
|
||||
def get_attri_head(flat_dim, fc_dim, attri_num_list):
|
||||
attri_head = AttriHead(flat_dim, fc_dim, attri_num_list)
|
||||
return attri_head
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute loss."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.FaceAttribute.cross_entropy import CrossEntropyWithIgnoreIndex
|
||||
|
||||
|
||||
__all__ = ['get_loss']
|
||||
|
||||
|
||||
class CriterionsFaceAttri(nn.Cell):
|
||||
'''Criterions Face Attribute.'''
|
||||
def __init__(self):
|
||||
super(CriterionsFaceAttri, self).__init__()
|
||||
|
||||
# label
|
||||
self.gatherv2 = P.GatherV2()
|
||||
self.squeeze = P.Squeeze(axis=1)
|
||||
self.cast = P.Cast()
|
||||
self.reshape = P.Reshape()
|
||||
self.mean = P.ReduceMean()
|
||||
|
||||
self.label0_param = Tensor([0], dtype=mstype.int32)
|
||||
self.label1_param = Tensor([1], dtype=mstype.int32)
|
||||
self.label2_param = Tensor([2], dtype=mstype.int32)
|
||||
|
||||
# loss
|
||||
self.ce_ignore_loss = CrossEntropyWithIgnoreIndex()
|
||||
self.printn = P.Print()
|
||||
|
||||
def construct(self, x0, x1, x2, label):
|
||||
'''Construct function.'''
|
||||
# each sub attribute loss
|
||||
label0 = self.squeeze(self.gatherv2(label, self.label0_param, 1))
|
||||
loss0 = self.ce_ignore_loss(x0, label0)
|
||||
|
||||
label1 = self.squeeze(self.gatherv2(label, self.label1_param, 1))
|
||||
loss1 = self.ce_ignore_loss(x1, label1)
|
||||
|
||||
label2 = self.squeeze(self.gatherv2(label, self.label2_param, 1))
|
||||
loss2 = self.ce_ignore_loss(x2, label2)
|
||||
|
||||
loss = loss0 + loss1 + loss2
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def get_loss():
|
||||
return CriterionsFaceAttri()
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute resnet18 backbone."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import TensorAdd
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from src.FaceAttribute.custom_net import Cut, bn_with_initialize, conv1x1, conv3x3
|
||||
from src.FaceAttribute.head_factory import get_attri_head
|
||||
|
||||
__all__ = ['get_resnet18']
|
||||
|
||||
|
||||
class IRBlock(Cell):
|
||||
'''IRBlock.'''
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(IRBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride=stride)
|
||||
self.bn1 = bn_with_initialize(planes)
|
||||
self.relu1 = P.ReLU()
|
||||
self.conv2 = conv3x3(planes, planes, stride=1)
|
||||
self.bn2 = bn_with_initialize(planes)
|
||||
|
||||
if downsample is None:
|
||||
self.downsample = Cut()
|
||||
else:
|
||||
self.downsample = downsample
|
||||
|
||||
self.add = TensorAdd()
|
||||
self.cast = P.Cast()
|
||||
self.relu2 = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
identity = self.downsample(x)
|
||||
out = self.add(out, identity)
|
||||
out = self.relu2(out)
|
||||
return out
|
||||
|
||||
|
||||
class DownSample(Cell):
|
||||
def __init__(self, inplanes, planes, expansion, stride):
|
||||
super(DownSample, self).__init__()
|
||||
self.conv1 = conv1x1(inplanes, planes * expansion, stride=stride, pad_mode="valid")
|
||||
self.bn1 = bn_with_initialize(planes * expansion)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
return out
|
||||
|
||||
|
||||
class MakeLayer(Cell):
|
||||
'''Make layer.'''
|
||||
def __init__(self, block, inplanes, planes, blocks, stride=1):
|
||||
super(MakeLayer, self).__init__()
|
||||
|
||||
self.inplanes = inplanes
|
||||
self.downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
self.downsample = DownSample(self.inplanes, planes, block.expansion, stride)
|
||||
|
||||
self.layers = []
|
||||
self.layers.append(block(self.inplanes, planes, stride, self.downsample))
|
||||
self.inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
self.layers.append(block(self.inplanes, planes))
|
||||
self.layers = nn.CellList(self.layers)
|
||||
|
||||
def construct(self, x):
|
||||
for block in self.layers:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttriResNet(Cell):
|
||||
'''Resnet for attribute.'''
|
||||
def __init__(self, block, layers, flat_dim, fc_dim, attri_num_list):
|
||||
super(AttriResNet, self).__init__()
|
||||
|
||||
# resnet18
|
||||
self.inplanes = 32
|
||||
self.conv1 = conv3x3(3, self.inplanes, stride=1)
|
||||
self.bn1 = bn_with_initialize(self.inplanes)
|
||||
self.relu = P.ReLU()
|
||||
self.layer1 = MakeLayer(block, inplanes=32, planes=64, blocks=layers[0], stride=2)
|
||||
self.layer2 = MakeLayer(block, inplanes=64, planes=128, blocks=layers[1], stride=2)
|
||||
self.layer3 = MakeLayer(block, inplanes=128, planes=256, blocks=layers[2], stride=2)
|
||||
self.layer4 = MakeLayer(block, inplanes=256, planes=512, blocks=layers[3], stride=2)
|
||||
|
||||
# avg global pooling
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.head = get_attri_head(flat_dim, fc_dim, attri_num_list)
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.mean(x, (2, 3))
|
||||
b, c, _, _ = self.shape(x)
|
||||
x = self.reshape(x, (b, c))
|
||||
return self.head(x)
|
||||
|
||||
|
||||
def get_resnet18(args):
|
||||
'''Build resnet18 for attribute.'''
|
||||
flat_dim = args.flat_dim
|
||||
fc_dim = args.fc_dim
|
||||
str_classes = args.classes.strip().split(',')
|
||||
if args.attri_num != len(str_classes):
|
||||
print('args warning: attri_num != classes num')
|
||||
return None
|
||||
attri_num_list = []
|
||||
for i, _ in enumerate(str_classes):
|
||||
attri_num_list.append(int(str_classes[i]))
|
||||
|
||||
attri_resnet18 = AttriResNet(IRBlock, (2, 2, 2, 2), flat_dim, fc_dim, attri_num_list)
|
||||
|
||||
return attri_resnet18
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute resnet18 backbone."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import TensorAdd
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from src.FaceAttribute.custom_net import Cut, bn_with_initialize, conv1x1, conv3x3
|
||||
from src.FaceAttribute.head_factory_softmax import get_attri_head
|
||||
|
||||
__all__ = ['get_resnet18']
|
||||
|
||||
|
||||
class IRBlock(Cell):
|
||||
'''IRBlock.'''
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(IRBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride=stride)
|
||||
self.bn1 = bn_with_initialize(planes)
|
||||
self.relu1 = P.ReLU()
|
||||
self.conv2 = conv3x3(planes, planes, stride=1)
|
||||
self.bn2 = bn_with_initialize(planes)
|
||||
|
||||
if downsample is None:
|
||||
self.downsample = Cut()
|
||||
else:
|
||||
self.downsample = downsample
|
||||
|
||||
self.add = TensorAdd()
|
||||
self.cast = P.Cast()
|
||||
self.relu2 = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
identity = self.downsample(x)
|
||||
out = self.add(out, identity)
|
||||
out = self.relu2(out)
|
||||
return out
|
||||
|
||||
|
||||
class DownSample(Cell):
|
||||
def __init__(self, inplanes, planes, expansion, stride):
|
||||
super(DownSample, self).__init__()
|
||||
self.conv1 = conv1x1(inplanes, planes * expansion, stride=stride, pad_mode="valid")
|
||||
self.bn1 = bn_with_initialize(planes * expansion)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
return out
|
||||
|
||||
|
||||
class MakeLayer(Cell):
|
||||
'''Make layer function.'''
|
||||
def __init__(self, block, inplanes, planes, blocks, stride=1):
|
||||
super(MakeLayer, self).__init__()
|
||||
|
||||
self.inplanes = inplanes
|
||||
self.downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
self.downsample = DownSample(self.inplanes, planes, block.expansion, stride)
|
||||
|
||||
self.layers = []
|
||||
self.layers.append(block(self.inplanes, planes, stride, self.downsample))
|
||||
self.inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
self.layers.append(block(self.inplanes, planes))
|
||||
self.layers = nn.CellList(self.layers)
|
||||
|
||||
def construct(self, x):
|
||||
for block in self.layers:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttriResNet(Cell):
|
||||
'''Resnet for attribute.'''
|
||||
def __init__(self, block, layers, flat_dim, fc_dim, attri_num_list):
|
||||
super(AttriResNet, self).__init__()
|
||||
|
||||
# resnet18
|
||||
self.inplanes = 32
|
||||
self.conv1 = conv3x3(3, self.inplanes, stride=1)
|
||||
self.bn1 = bn_with_initialize(self.inplanes)
|
||||
self.relu = P.ReLU()
|
||||
self.layer1 = MakeLayer(block, inplanes=32, planes=64, blocks=layers[0], stride=2)
|
||||
self.layer2 = MakeLayer(block, inplanes=64, planes=128, blocks=layers[1], stride=2)
|
||||
self.layer3 = MakeLayer(block, inplanes=128, planes=256, blocks=layers[2], stride=2)
|
||||
self.layer4 = MakeLayer(block, inplanes=256, planes=512, blocks=layers[3], stride=2)
|
||||
|
||||
# avg global pooling
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.head = get_attri_head(flat_dim, fc_dim, attri_num_list)
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.mean(x, (2, 3))
|
||||
b, c, _, _ = self.shape(x)
|
||||
x = self.reshape(x, (b, c))
|
||||
return self.head(x)
|
||||
|
||||
|
||||
def get_resnet18(args):
|
||||
'''Build resnet18 for attribute.'''
|
||||
flat_dim = args.flat_dim
|
||||
fc_dim = args.fc_dim
|
||||
str_classes = args.classes.strip().split(',')
|
||||
if args.attri_num != len(str_classes):
|
||||
print('args warning: attri_num != classes num')
|
||||
return None
|
||||
attri_num_list = []
|
||||
for i, _ in enumerate(str_classes):
|
||||
attri_num_list.append(int(str_classes[i]))
|
||||
|
||||
attri_resnet18 = AttriResNet(IRBlock, (2, 2, 2, 2), flat_dim, fc_dim, attri_num_list)
|
||||
|
||||
return attri_resnet18
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Network config setting, will be used in train.py and eval.py"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
'per_batch_size': 128,
|
||||
'dst_h': 112,
|
||||
'dst_w': 112,
|
||||
'workers': 8,
|
||||
'attri_num': 3,
|
||||
'classes': '9,2,2',
|
||||
'backbone': 'resnet18',
|
||||
'loss_scale': 1024,
|
||||
'flat_dim': 512,
|
||||
'fc_dim': 256,
|
||||
'lr': 0.009,
|
||||
'lr_scale': 1,
|
||||
'lr_epochs': [20, 30, 50],
|
||||
'weight_decay': 0.0005,
|
||||
'momentum': 0.9,
|
||||
'max_epoch': 70,
|
||||
'warmup_epochs': 0,
|
||||
'log_interval': 10,
|
||||
'ckpt_path': '../../output',
|
||||
|
||||
# data_to_mindrecord parameter
|
||||
'eval_dataset_txt_file': 'Your_label_txt_file',
|
||||
'eval_mindrecord_file_name': 'Your_output_path/data_test.mindrecord',
|
||||
'train_dataset_txt_file': 'Your_label_txt_file',
|
||||
'train_mindrecord_file_name': 'Your_output_path/data_train.mindrecord',
|
||||
'train_append_dataset_txt_file': 'Your_label_txt_file',
|
||||
'train_append_mindrecord_file_name': 'Your_previous_output_path/data_train.mindrecord0'
|
||||
})
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert dataset to mindrecord for evaluating Face attribute."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from config import config
|
||||
|
||||
dataset_txt_file = config.eval_dataset_txt_file
|
||||
|
||||
mindrecord_file_name = config.eval_mindrecord_file_name
|
||||
|
||||
mindrecord_num = 8
|
||||
|
||||
|
||||
def convert_data_to_mindrecord():
|
||||
'''Convert data to mindrecord.'''
|
||||
writer = FileWriter(mindrecord_file_name, mindrecord_num)
|
||||
attri_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"label": {"type": "int32", "shape": [-1]}
|
||||
}
|
||||
|
||||
print('Loading eval data...')
|
||||
total_data = []
|
||||
with open(dataset_txt_file, 'r') as ft:
|
||||
lines = ft.readlines()
|
||||
for line in lines:
|
||||
sline = line.strip().split(" ")
|
||||
image_file = sline[0]
|
||||
labels = []
|
||||
for item in sline[1:]:
|
||||
labels.append(int(item))
|
||||
|
||||
with open(image_file, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
data = {
|
||||
"image": img,
|
||||
"label": np.array(labels, dtype='int32')
|
||||
}
|
||||
|
||||
total_data.append(data)
|
||||
|
||||
print('Writing eval data to mindrecord...')
|
||||
writer.add_schema(attri_json, "attri_json")
|
||||
if total_data is None:
|
||||
raise ValueError("None needs writing to mindrecord.")
|
||||
writer.write_raw_data(total_data)
|
||||
writer.commit()
|
||||
|
||||
|
||||
convert_data_to_mindrecord()
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert dataset to mindrecord for training Face attribute."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from config import config
|
||||
|
||||
dataset_txt_file = config.train_dataset_txt_file
|
||||
|
||||
mindrecord_file_name = config.train_mindrecord_file_name
|
||||
|
||||
mindrecord_num = 8
|
||||
|
||||
|
||||
def convert_data_to_mindrecord():
|
||||
'''Covert data to mindrecord.'''
|
||||
writer = FileWriter(mindrecord_file_name, mindrecord_num)
|
||||
attri_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"label": {"type": "int32", "shape": [-1]}
|
||||
}
|
||||
|
||||
print('Loading train data...')
|
||||
total_data = []
|
||||
with open(dataset_txt_file, 'r') as ft:
|
||||
lines = ft.readlines()
|
||||
for line in lines:
|
||||
sline = line.strip().split(" ")
|
||||
image_file = sline[0]
|
||||
labels = []
|
||||
for item in sline[1:]:
|
||||
labels.append(int(item))
|
||||
|
||||
with open(image_file, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
data = {
|
||||
"image": img,
|
||||
"label": np.array(labels, dtype='int32')
|
||||
}
|
||||
|
||||
total_data.append(data)
|
||||
|
||||
print('Writing train data to mindrecord...')
|
||||
writer.add_schema(attri_json, "attri_json")
|
||||
if total_data is None:
|
||||
raise ValueError("None needs writing to mindrecord.")
|
||||
writer.write_raw_data(total_data)
|
||||
writer.commit()
|
||||
|
||||
|
||||
convert_data_to_mindrecord()
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Add dataset to an existed mindrecord for training Face attribute."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from config import config
|
||||
|
||||
dataset_txt_file = config.train_append_dataset_txt_file
|
||||
|
||||
mindrecord_file_name = config.train_append_mindrecord_file_name
|
||||
|
||||
mindrecord_num = 8
|
||||
|
||||
|
||||
def convert_data_to_mindrecord():
|
||||
'''Covert data to mindrecord.'''
|
||||
print('Loading mindrecord...')
|
||||
writer = FileWriter.open_for_append(mindrecord_file_name)
|
||||
|
||||
print('Loading train data...')
|
||||
total_data = []
|
||||
with open(dataset_txt_file, 'r') as ft:
|
||||
lines = ft.readlines()
|
||||
for line in lines:
|
||||
sline = line.strip().split(" ")
|
||||
image_file = sline[0]
|
||||
labels = []
|
||||
for item in sline[1:]:
|
||||
labels.append(int(item))
|
||||
|
||||
with open(image_file, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
data = {
|
||||
"image": img,
|
||||
"label": np.array(labels, dtype='int32')
|
||||
}
|
||||
|
||||
total_data.append(data)
|
||||
|
||||
print('Writing train data to mindrecord...')
|
||||
if total_data is None:
|
||||
raise ValueError("None needs writing to mindrecord.")
|
||||
writer.write_raw_data(total_data)
|
||||
writer.commit()
|
||||
|
||||
|
||||
convert_data_to_mindrecord()
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute dataset for eval"""
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.py_transforms as F
|
||||
import mindspore.dataset.transforms.py_transforms as F2
|
||||
|
||||
__all__ = ['data_generator_eval']
|
||||
|
||||
|
||||
def data_generator_eval(args):
|
||||
'''Build eval dataloader.'''
|
||||
mindrecord_path = args.mindrecord_path
|
||||
dst_w = args.dst_w
|
||||
dst_h = args.dst_h
|
||||
batch_size = 1
|
||||
attri_num = args.attri_num
|
||||
transform_img = F2.Compose([F.Decode(),
|
||||
F.Resize((dst_w, dst_h)),
|
||||
F.ToTensor(),
|
||||
F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
de_dataset = de.MindDataset(mindrecord_path + "0", columns_list=["image", "label"])
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=args.workers,
|
||||
python_multiprocessing=True)
|
||||
de_dataset = de_dataset.batch(batch_size)
|
||||
|
||||
de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
|
||||
steps_per_epoch = de_dataset.get_dataset_size()
|
||||
print("image number:{0}".format(steps_per_epoch))
|
||||
num_classes = attri_num
|
||||
|
||||
return de_dataloader, steps_per_epoch, num_classes
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute dataset for train"""
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.py_transforms as F
|
||||
import mindspore.dataset.transforms.py_transforms as F2
|
||||
|
||||
__all__ = ['data_generator']
|
||||
|
||||
|
||||
def data_generator(args):
|
||||
'''Build train dataloader.'''
|
||||
mindrecord_path = args.mindrecord_path
|
||||
dst_w = args.dst_w
|
||||
dst_h = args.dst_h
|
||||
batch_size = args.per_batch_size
|
||||
attri_num = args.attri_num
|
||||
max_epoch = args.max_epoch
|
||||
transform_img = F2.Compose([F.Decode(),
|
||||
F.Resize((dst_w, dst_h)),
|
||||
F.RandomHorizontalFlip(prob=0.5),
|
||||
F.ToTensor(),
|
||||
F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
de_dataset = de.MindDataset(mindrecord_path + "0", columns_list=["image", "label"], num_shards=args.world_size,
|
||||
shard_id=args.local_rank)
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=args.workers,
|
||||
python_multiprocessing=True)
|
||||
de_dataset = de_dataset.batch(batch_size, drop_remainder=True)
|
||||
steps_per_epoch = de_dataset.get_dataset_size()
|
||||
de_dataset = de_dataset.repeat(max_epoch)
|
||||
de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
|
||||
|
||||
num_classes = attri_num
|
||||
|
||||
return de_dataloader, steps_per_epoch, num_classes
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom logger."""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
logger_name_1 = 'face_attributes'
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
|
||||
'''Logger.'''
|
||||
def __init__(self, logger_name):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
self.local_rank = 0
|
||||
|
||||
def setup_logging_file(self, log_dir, local_rank=0):
|
||||
'''Setup logging file.'''
|
||||
self.local_rank = local_rank
|
||||
if self.local_rank == 0:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
|
||||
self.log_fn = os.path.join(log_dir, log_name)
|
||||
fh = logging.FileHandler(self.log_fn)
|
||||
fh.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
fh.setFormatter(formatter)
|
||||
self.addHandler(fh)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
|
||||
logger = LOGGER(logger_name_1)
|
||||
logger.setup_logging_file(path, rank)
|
||||
return logger
|
||||
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self, name, fmt=':f', tb_writer=None):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
self.tb_writer = tb_writer
|
||||
self.cur_step = 1
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
|
||||
self.cur_step += 1
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name}:{avg' + self.fmt + '}'
|
||||
return fmtstr.format(**self.__dict__)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute learning rate scheduler."""
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
learning_rate = float(init_lr) + lr_inc * current_step
|
||||
return learning_rate
|
||||
|
||||
|
||||
def warmup_step(args, gamma=0.1):
|
||||
'''Warmup step.'''
|
||||
base_lr = args.lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(args.max_epoch * args.steps_per_epoch)
|
||||
warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
|
||||
milestones = args.lr_epochs
|
||||
milestones_steps = []
|
||||
for milestone in milestones:
|
||||
milestones_step = milestone * args.steps_per_epoch
|
||||
milestones_steps.append(milestones_step)
|
||||
|
||||
lr = base_lr
|
||||
milestones_steps_counter = Counter(milestones_steps)
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_learning_rate(i, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = lr * gamma**milestones_steps_counter[i]
|
||||
yield lr
|
|
@ -0,0 +1,232 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute train."""
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
import argparse
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.communication.management import get_group_size, init, get_rank
|
||||
from mindspore.nn import TrainOneStepCell
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.FaceAttribute.resnet18 import get_resnet18
|
||||
from src.FaceAttribute.loss_factory import get_loss
|
||||
from src.dataset_train import data_generator
|
||||
from src.lrsche_factory import warmup_step
|
||||
from src.logging import get_logger, AverageMeter
|
||||
from src.config import config
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||
|
||||
|
||||
class BuildTrainNetwork(nn.Cell):
|
||||
'''Build train network.'''
|
||||
def __init__(self, network, criterion):
|
||||
super(BuildTrainNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.criterion = criterion
|
||||
self.print = P.Print()
|
||||
|
||||
def construct(self, input_data, label):
|
||||
logit0, logit1, logit2 = self.network(input_data)
|
||||
loss = self.criterion(logit0, logit1, logit2, label)
|
||||
return loss
|
||||
|
||||
|
||||
def parse_args():
|
||||
'''Argument for Face Attributes.'''
|
||||
parser = argparse.ArgumentParser('Face Attributes')
|
||||
|
||||
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
|
||||
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
|
||||
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
|
||||
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def train():
|
||||
'''train function.'''
|
||||
# logger
|
||||
args = parse_args()
|
||||
|
||||
# init distributed
|
||||
if args.world_size != 1:
|
||||
init()
|
||||
args.local_rank = get_rank()
|
||||
args.world_size = get_group_size()
|
||||
|
||||
args.per_batch_size = config.per_batch_size
|
||||
args.dst_h = config.dst_h
|
||||
args.dst_w = config.dst_w
|
||||
args.workers = config.workers
|
||||
args.attri_num = config.attri_num
|
||||
args.classes = config.classes
|
||||
args.backbone = config.backbone
|
||||
args.loss_scale = config.loss_scale
|
||||
args.flat_dim = config.flat_dim
|
||||
args.fc_dim = config.fc_dim
|
||||
args.lr = config.lr
|
||||
args.lr_scale = config.lr_scale
|
||||
args.lr_epochs = config.lr_epochs
|
||||
args.weight_decay = config.weight_decay
|
||||
args.momentum = config.momentum
|
||||
args.max_epoch = config.max_epoch
|
||||
args.warmup_epochs = config.warmup_epochs
|
||||
args.log_interval = config.log_interval
|
||||
args.ckpt_path = config.ckpt_path
|
||||
|
||||
if args.world_size == 1:
|
||||
args.per_batch_size = 256
|
||||
else:
|
||||
args.lr = args.lr * 4.
|
||||
|
||||
if args.world_size != 1:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
else:
|
||||
parallel_mode = ParallelMode.STAND_ALONE
|
||||
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)
|
||||
|
||||
# model and log save path
|
||||
args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
args.logger = get_logger(args.outputs_dir, args.local_rank)
|
||||
loss_meter = AverageMeter('loss')
|
||||
|
||||
# dataloader
|
||||
args.logger.info('start create dataloader')
|
||||
de_dataloader, steps_per_epoch, num_classes = data_generator(args)
|
||||
args.steps_per_epoch = steps_per_epoch
|
||||
args.num_classes = num_classes
|
||||
args.logger.info('end create dataloader')
|
||||
args.logger.save_args(args)
|
||||
|
||||
# backbone and loss
|
||||
args.logger.important_info('start create network')
|
||||
create_network_start = time.time()
|
||||
network = get_resnet18(args)
|
||||
|
||||
criterion = get_loss()
|
||||
|
||||
# load pretrain model
|
||||
if os.path.isfile(args.pretrained):
|
||||
param_dict = load_checkpoint(args.pretrained)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('network.'):
|
||||
param_dict_new[key[8:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
args.logger.info('load model {} success'.format(args.pretrained))
|
||||
|
||||
# optimizer and lr scheduler
|
||||
lr = warmup_step(args, gamma=0.1)
|
||||
opt = Momentum(params=network.trainable_params(),
|
||||
learning_rate=lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
loss_scale=args.loss_scale)
|
||||
|
||||
train_net = BuildTrainNetwork(network, criterion)
|
||||
|
||||
# mixed precision training
|
||||
criterion.add_flags_recursive(fp32=True)
|
||||
|
||||
# package training process
|
||||
train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
# checkpoint
|
||||
if args.local_rank == 0:
|
||||
ckpt_max_num = args.max_epoch
|
||||
train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
|
||||
ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = train_net
|
||||
cb_params.epoch_num = ckpt_max_num
|
||||
cb_params.cur_epoch_num = 0
|
||||
run_context = RunContext(cb_params)
|
||||
ckpt_cb.begin(run_context)
|
||||
|
||||
train_net.set_train()
|
||||
t_end = time.time()
|
||||
t_epoch = time.time()
|
||||
old_progress = -1
|
||||
|
||||
i = 0
|
||||
for _, (data, gt_classes) in enumerate(de_dataloader):
|
||||
|
||||
data_tensor = Tensor(data, dtype=mstype.float32)
|
||||
gt_tensor = Tensor(gt_classes, dtype=mstype.int32)
|
||||
|
||||
loss = train_net(data_tensor, gt_tensor)
|
||||
loss_meter.update(loss.asnumpy()[0])
|
||||
|
||||
# save ckpt
|
||||
if args.local_rank == 0:
|
||||
cb_params.cur_step_num = i + 1
|
||||
cb_params.batch_num = i + 2
|
||||
ckpt_cb.step_end(run_context)
|
||||
|
||||
if i % args.steps_per_epoch == 0 and args.local_rank == 0:
|
||||
cb_params.cur_epoch_num += 1
|
||||
|
||||
# save Log
|
||||
if i == 0:
|
||||
time_for_graph_compile = time.time() - create_network_start
|
||||
args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))
|
||||
|
||||
if i % args.log_interval == 0 and args.local_rank == 0:
|
||||
time_used = time.time() - t_end
|
||||
epoch = int(i / args.steps_per_epoch)
|
||||
fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
|
||||
args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
|
||||
|
||||
t_end = time.time()
|
||||
loss_meter.reset()
|
||||
old_progress = i
|
||||
|
||||
if i % args.steps_per_epoch == 0 and args.local_rank == 0:
|
||||
epoch_time_used = time.time() - t_epoch
|
||||
epoch = int(i / args.steps_per_epoch)
|
||||
fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
|
||||
args.logger.info('=================================================')
|
||||
args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
|
||||
args.logger.info('=================================================')
|
||||
t_epoch = time.time()
|
||||
|
||||
i += 1
|
||||
|
||||
args.logger.info('--------- trains out ---------')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
Loading…
Reference in New Issue