Add FaceAttribute net to model_zoo/research/cv/

zhanghuiyao 2020-11-28 15:43:58 +08:00
parent 6cf308076d
commit 2705bb2ca5
25 changed files with 2165 additions and 0 deletions

@ -0,0 +1,271 @@
# Contents
- [Face Attribute Description](#face-attribute-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Running Example](#running-example)
- [Model Description](#model-description)
- [Performance](#performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Face Attribute Description](#contents)
This is a face attribute recognition network based on ResNet18, with support for training and evaluation on Ascend 910.
ResNet (residual neural network) was proposed by Kaiming He and three colleagues at Microsoft Research. Using residual units, they successfully trained a 152-layer network and won ILSVRC 2015 with a top-5 error rate of 3.57%, while using fewer parameters than VGGNet. Traditional convolutional or fully connected networks lose some information as it passes through the layers, and very deep networks also suffer from vanishing or exploding gradients, which makes them fail to train. ResNet mitigates this problem: shortcut connections pass the input directly to the output, which preserves the information, so each block only needs to learn the difference (the residual) between its input and output. This simplifies the learning objective, speeds up training, and improves accuracy. The residual structure is also widely reused, and has even been adopted directly in Inception-style networks.
[Paper](https://arxiv.org/pdf/1512.03385.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"
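The residual idea described above can be illustrated with a minimal pure-Python sketch. It is not the network code from this repository: `residual_block`, `zero_branch` and `halve_branch` are hypothetical names, and plain lists stand in for feature tensors.

```python
from typing import Callable, List

Vector = List[float]


def residual_block(x: Vector, residual_fn: Callable[[Vector], Vector]) -> Vector:
    """Return residual_fn(x) + x: the identity shortcut plus a learned residual."""
    fx = residual_fn(x)
    if len(fx) != len(x):
        raise ValueError("the residual branch must preserve the input shape")
    return [xi + fi for xi, fi in zip(x, fx)]


def zero_branch(x: Vector) -> Vector:
    """Hypothetical branch that has learned nothing: the block becomes the identity."""
    return [0.0 for _ in x]


def halve_branch(x: Vector) -> Vector:
    """Hypothetical branch that learned to subtract half of the input."""
    return [-0.5 * xi for xi in x]


identity_out = residual_block([1.0, -2.0, 3.0], zero_branch)
halved_out = residual_block([1.0, -2.0, 3.0], halve_branch)
```

With a zero residual the block passes its input through unchanged, which is why stacking more such blocks should not make the network harder to fit than a shallower one.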
# [Model Architecture](#contents)
Face Attribute uses a modified-Resnet18 network for performing feature extraction.
# [Dataset](#contents)
This network can recognize the age/gender/mask from a human face. The default rule is:
```python
age:
0: 0~2 years
1: 3~9 years
2: 10~19 years
3: 20~29 years
4: 30~39 years
5: 40~49 years
6: 50~59 years
7: 60~69 years
8: 70+ years
gender:
0: male
1: female
mask:
0: wearing mask
1: without mask
```
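As an illustration of the default rule above, the following sketch maps an age in years to its class index. The helper name `age_to_class` and the lookup tables are assumptions made for this example and are not part of the repository.

```python
# Inclusive upper bound (in years) of age classes 0..7; class 8 covers 70 and above.
AGE_UPPER_BOUNDS = [2, 9, 19, 29, 39, 49, 59, 69]

GENDER_NAMES = {0: "male", 1: "female"}
MASK_NAMES = {0: "wearing mask", 1: "without mask"}


def age_to_class(age_years: int) -> int:
    """Map an age in whole years to the age class index (0..8) of the default rule."""
    if age_years < 0:
        raise ValueError("age must be non-negative")
    for class_idx, upper in enumerate(AGE_UPPER_BOUNDS):
        if age_years <= upper:
            return class_idx
    return 8
```

For example, `age_to_class(25)` falls in the 20~29 bucket and `age_to_class(70)` in the last bucket.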
We use about 91K face images as the training dataset and 11K as the evaluation dataset in this example. You can also use your own datasets or open-source datasets (e.g. FairFace and RWMFD).
- step 1: Save the dataset as a txt file with one line per image in the following format:
```python
[PATH_TO_IMAGE]/1.jpg [LABEL_AGE] [LABEL_GENDER] [LABEL_MASK]
[PATH_TO_IMAGE]/2.jpg [LABEL_AGE] [LABEL_GENDER] [LABEL_MASK]
[PATH_TO_IMAGE]/3.jpg [LABEL_AGE] [LABEL_GENDER] [LABEL_MASK]
...
```
Valid values for [LABEL_AGE] are -1 and 0 to 8; -1 means the label is ignored.
Valid values for [LABEL_GENDER] are -1, 0 and 1; -1 means the label is ignored.
Valid values for [LABEL_MASK] are -1, 0 and 1; -1 means the label is ignored.
- step 2: Convert the dataset to mindrecord:
```bash
python src/data_to_mindrecord_train.py
```
or
```bash
python src/data_to_mindrecord_eval.py
```
If your dataset is too large to convert in one pass, you can append data to an existing mindrecord file incrementally:
```bash
python src/data_to_mindrecord_train_append.py
```
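Before converting, it can help to check that every line of the txt file follows the format from step 1. The sketch below is one possible validator. `Sample` and `parse_label_line` are hypothetical helpers, not part of `src/`, and the sketch assumes the three labels are the last three whitespace-separated fields of each line.

```python
from typing import NamedTuple


class Sample(NamedTuple):
    image_path: str
    age: int
    gender: int
    mask: int


VALID_AGE = set(range(-1, 9))   # -1 (ignore) and classes 0..8
VALID_BINARY = {-1, 0, 1}       # -1 (ignore), 0, 1


def parse_label_line(line: str) -> Sample:
    """Parse '[PATH_TO_IMAGE] [LABEL_AGE] [LABEL_GENDER] [LABEL_MASK]' and validate the labels."""
    # Split from the right so that a path containing spaces stays in one piece.
    parts = line.strip().rsplit(None, 3)
    if len(parts) != 4:
        raise ValueError("expected 4 fields, got {}: {!r}".format(len(parts), line))
    path = parts[0]
    age, gender, mask = (int(p) for p in parts[1:])
    if age not in VALID_AGE:
        raise ValueError("age label out of range: {}".format(age))
    if gender not in VALID_BINARY:
        raise ValueError("gender label out of range: {}".format(gender))
    if mask not in VALID_BINARY:
        raise ValueError("mask label out of range: {}".format(mask))
    return Sample(path, age, gender, mask)


example = parse_label_line("/data/faces/1.jpg 3 0 -1")
```

Here `example` describes an image in age class 3, labelled male, whose mask label is ignored.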
# [Environment Requirements](#contents)
- Hardware(Ascend)
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
- For more information, please check the resources below:
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
The entire code structure is as follows:
```python
.
└─ Face Attribute
  ├─ README.md
  ├─ scripts
    ├─ run_standalone_train.sh              # launch standalone training(1p) in ascend
    ├─ run_distribute_train.sh              # launch distributed training(8p) in ascend
    ├─ run_eval.sh                          # launch evaluating in ascend
    └─ run_export.sh                        # launch exporting air model
  ├─ src
    ├─ FaceAttribute
      ├─ cross_entropy.py                   # cross entropy loss
      ├─ custom_net.py                      # network unit
      ├─ loss_factory.py                    # loss function
      ├─ head_factory.py                    # network head
      ├─ resnet18.py                        # network backbone
      ├─ head_factory_softmax.py            # network head with softmax
      └─ resnet18_softmax.py                # network backbone with softmax
    ├─ config.py                            # parameter configuration
    ├─ dataset_eval.py                      # dataset loading and preprocessing for evaluating
    ├─ dataset_train.py                     # dataset loading and preprocessing for training
    ├─ logging.py                           # log function
    ├─ lrsche_factory.py                    # generate learning rate
    ├─ data_to_mindrecord_train.py          # convert dataset to mindrecord for training
    ├─ data_to_mindrecord_train_append.py   # add dataset to an existing mindrecord for training
    └─ data_to_mindrecord_eval.py           # convert dataset to mindrecord for evaluating
  ├─ train.py                               # training scripts
  ├─ eval.py                                # evaluation scripts
  └─ export.py                              # export air model
```
## [Running Example](#contents)
### Train
- Standalone mode
```bash
cd ./scripts
sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]
```
or (fine-tune)
```bash
cd ./scripts
sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_standalone_train.sh /home/train.mindrecord 0 /home/a.ckpt
```
- Distribute mode (recommended)
```bash
cd ./scripts
sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]
```
or (fine-tune)
```bash
cd ./scripts
sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_distribute_train.sh /home/train.mindrecord ./rank_table_8p.json /home/a.ckpt
```
The loss value of each step is written to "./output/[TIME]/[TIME].log" or "./scripts/device0/train.log", as follows:
```python
epoch[0], iter[0], loss:4.489518, 12.92 imgs/sec
epoch[0], iter[10], loss:3.619693, 13792.76 imgs/sec
epoch[0], iter[20], loss:3.580932, 13817.78 imgs/sec
epoch[0], iter[30], loss:3.574254, 7834.65 imgs/sec
epoch[0], iter[40], loss:3.557742, 7884.87 imgs/sec
...
epoch[69], iter[6120], loss:1.225308, 9561.00 imgs/sec
epoch[69], iter[6130], loss:1.209557, 8913.28 imgs/sec
epoch[69], iter[6140], loss:1.158641, 9755.81 imgs/sec
epoch[69], iter[6150], loss:1.167064, 9300.77 imgs/sec
```
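To plot or monitor training, the log lines above can be parsed with a small script. This is only a sketch: the regular expression is derived from the sample lines shown here, and `parse_log_line` is a hypothetical helper that is not shipped with the repository.

```python
import re
from typing import Optional, Tuple

# Matches lines such as: epoch[0], iter[10], loss:3.619693, 13792.76 imgs/sec
LOG_PATTERN = re.compile(
    r"epoch\[(\d+)\], iter\[(\d+)\], loss:([0-9.]+), ([0-9.]+) imgs/sec"
)


def parse_log_line(line: str) -> Optional[Tuple[int, int, float, float]]:
    """Return (epoch, iteration, loss, images_per_second), or None for other lines."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    return (
        int(match.group(1)),
        int(match.group(2)),
        float(match.group(3)),
        float(match.group(4)),
    )


record = parse_log_line("epoch[0], iter[10], loss:3.619693, 13792.76 imgs/sec")
```

Lines that do not match the pattern, such as the `...` separator, yield `None` and can simply be skipped.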
### Evaluation
```bash
cd ./scripts
sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
for example:
```bash
cd ./scripts
sh run_eval.sh /home/eval.mindrecord 0 /home/a.ckpt
```
The results are written to "./scripts/device0/eval.log" and to a txt file in the folder that contains [PRETRAINED_BACKBONE], as follows:
```python
age accuracy: 0.45773233522001094
gen accuracy: 0.8950155194449516
mask accuracy: 0.992539346357495
gen precision: 0.8869598765432098
gen recall: 0.8907400232468036
gen f1: 0.88884593079451
mask precision: 1.0
mask recall: 0.998539346357495
mask f1: 0.9992691394116572
```
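In `eval.py`, precision, recall and F1 are computed from true-positive, false-positive and false-negative counts, with class 1 treated as the positive class. The sketch below restates those formulas. The function names are chosen for this example, and the zero-denominator guards are an addition that the evaluation script does not have.

```python
from typing import Tuple


def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    denom = precision + recall
    return 2.0 * precision * recall / denom if denom > 0 else 0.0


def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Compute (precision, recall, f1) from true/false positive and false negative counts."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return precision, recall, f1_score(precision, recall)


# Cross-check against the values logged above.
mask_f1 = f1_score(1.0, 0.998539346357495)
gen_f1 = f1_score(0.8869598765432098, 0.8907400232468036)
```

Both `mask_f1` and `gen_f1` agree with the logged `mask f1` and `gen f1` values to about six decimal places.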
### Convert model
If you want to run inference on Ascend 310, you should first convert the model to AIR:
```bash
cd ./scripts
sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]
```
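`export.py` derives the output file name from the checkpoint path and the batch size. The sketch below only reproduces that string manipulation. `air_file_name` is a hypothetical helper, and whether the framework adds or changes the extension when writing the file is not covered here.

```python
def air_file_name(ckpt_path: str, batch_size: int) -> str:
    """Mirror the naming used in export.py: '<name>.ckpt' -> '<name>_<batch_size>b.air'."""
    return ckpt_path.replace('.ckpt', '_' + str(batch_size) + 'b.air')


exported = air_file_name('/home/a.ckpt', 8)
```

Under this rule, exporting `/home/a.ckpt` with a batch size of 8 targets a file next to the checkpoint rather than in the `scripts` folder.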
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | Face Attribute |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 91K images |
| Training Parameters | epoch=70, batch_size=128, momentum=0.9, lr=0.001 |
| Optimizer | Momentum |
| Loss Function | Softmax Cross Entropy |
| Outputs | probability |
| Speed | 1pc: 200~250 ms/step; 8pcs: 100~150 ms/step |
| Total time | 1pc: 2.5 hours; 8pcs: 0.3 hours |
| Checkpoint for Fine tuning | 88M (.ckpt file) |
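
The training parameters can be cross-checked against the sample training log with a little arithmetic. This is a rough sketch: it assumes exactly 91,000 images (the table says about 91K) and that the last incomplete batch of each epoch is dropped, neither of which is confirmed by the source.

```python
def steps_per_epoch(num_images: int, batch_size: int, num_devices: int) -> int:
    """Number of full global batches per epoch (incomplete batches are dropped)."""
    return num_images // (batch_size * num_devices)


EPOCHS = 70
steps_8p = steps_per_epoch(91000, 128, 8)   # 8 devices, batch_size=128 per device
total_steps_8p = steps_8p * EPOCHS
```

The resulting total of roughly 6,160 steps is close to the last iteration shown in the sample log (`iter[6150]` in epoch 69).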
### Evaluation Performance
| Parameters | Face Attribute |
| ------------------- | --------------------------- |
| Model Version | V1 |
| Resource | Ascend 910 |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 11K images |
| batch_size | 1 |
| Outputs | accuracy |
| Accuracy(8pcs) | age:45.7% |
| | gender:89.5% |
| | mask:99.2% |
| Model for inference | 88M (.ckpt file) |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,189 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute eval."""
import os
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype
from src.dataset_eval import data_generator_eval
from src.config import config
from src.FaceAttribute.resnet18 import get_resnet18
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
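def softmax(x, axis=0):
    # Subtract the max before exponentiating for numerical stability.
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)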
def main(args):
    network = get_resnet18(args)
    ckpt_path = args.model_path
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        print('-----------------------load model success-----------------------')
    else:
        print('-----------------------load model failed-----------------------')

    network.set_train(False)

    de_dataloader, steps_per_epoch, _ = data_generator_eval(args)
    total_data_num_age = 0
    total_data_num_gen = 0
    total_data_num_mask = 0
    age_num = 0
    gen_num = 0
    mask_num = 0
    gen_tp_num = 0
    mask_tp_num = 0
    gen_fp_num = 0
    mask_fp_num = 0
    gen_fn_num = 0
    mask_fn_num = 0
    for step_i, (data, gt_classes) in enumerate(de_dataloader):
        print('evaluating {}/{} ...'.format(step_i + 1, steps_per_epoch))

        data_tensor = Tensor(data, dtype=mstype.float32)
        fea = network(data_tensor)

        gt_age, gt_gen, gt_mask = gt_classes[0]

        age_result, gen_result, mask_result = fea

        age_result_np = age_result.asnumpy()
        gen_result_np = gen_result.asnumpy()
        mask_result_np = mask_result.asnumpy()

        age_prob = softmax(age_result_np[0].astype(np.float32)).tolist()
        gen_prob = softmax(gen_result_np[0].astype(np.float32)).tolist()
        mask_prob = softmax(mask_result_np[0].astype(np.float32)).tolist()

        age = age_prob.index(max(age_prob))
        gen = gen_prob.index(max(gen_prob))
        mask = mask_prob.index(max(mask_prob))

        if gt_age == age:
            age_num += 1
        if gt_gen == gen:
            gen_num += 1
        if gt_mask == mask:
            mask_num += 1

        if gt_gen == 1 and gen == 1:
            gen_tp_num += 1
        if gt_gen == 0 and gen == 1:
            gen_fp_num += 1
        if gt_gen == 1 and gen == 0:
            gen_fn_num += 1

        if gt_mask == 1 and mask == 1:
            mask_tp_num += 1
        if gt_mask == 0 and mask == 1:
            mask_fp_num += 1
        if gt_mask == 1 and mask == 0:
            mask_fn_num += 1

        if gt_age != -1:
            total_data_num_age += 1
        if gt_gen != -1:
            total_data_num_gen += 1
        if gt_mask != -1:
            total_data_num_mask += 1

    age_accuracy = float(age_num) / float(total_data_num_age)

    gen_precision = float(gen_tp_num) / (float(gen_tp_num) + float(gen_fp_num))
    gen_recall = float(gen_tp_num) / (float(gen_tp_num) + float(gen_fn_num))
    gen_accuracy = float(gen_num) / float(total_data_num_gen)
    gen_f1 = 2. * gen_precision * gen_recall / (gen_precision + gen_recall)

    mask_precision = float(mask_tp_num) / (float(mask_tp_num) + float(mask_fp_num))
    mask_recall = float(mask_tp_num) / (float(mask_tp_num) + float(mask_fn_num))
    mask_accuracy = float(mask_num) / float(total_data_num_mask)
    mask_f1 = 2. * mask_precision * mask_recall / (mask_precision + mask_recall)

    print('model: ', ckpt_path)
    print('total age num: ', total_data_num_age)
    print('total gen num: ', total_data_num_gen)
    print('total mask num: ', total_data_num_mask)
    print('age accuracy: ', age_accuracy)
    print('gen accuracy: ', gen_accuracy)
    print('mask accuracy: ', mask_accuracy)
    print('gen precision: ', gen_precision)
    print('gen recall: ', gen_recall)
    print('gen f1: ', gen_f1)
    print('mask precision: ', mask_precision)
    print('mask recall: ', mask_recall)
    print('mask f1: ', mask_f1)

    model_name = os.path.basename(ckpt_path).split('.')[0]
    model_dir = os.path.dirname(ckpt_path)
    result_txt = os.path.join(model_dir, model_name + '.txt')
    if os.path.exists(result_txt):
        os.remove(result_txt)
    with open(result_txt, 'a') as ft:
        ft.write('model: {}\n'.format(ckpt_path))
        ft.write('total age num: {}\n'.format(total_data_num_age))
        ft.write('total gen num: {}\n'.format(total_data_num_gen))
        ft.write('total mask num: {}\n'.format(total_data_num_mask))
        ft.write('age accuracy: {}\n'.format(age_accuracy))
        ft.write('gen accuracy: {}\n'.format(gen_accuracy))
        ft.write('mask accuracy: {}\n'.format(mask_accuracy))
        ft.write('gen precision: {}\n'.format(gen_precision))
        ft.write('gen recall: {}\n'.format(gen_recall))
        ft.write('gen f1: {}\n'.format(gen_f1))
        ft.write('mask precision: {}\n'.format(mask_precision))
        ft.write('mask recall: {}\n'.format(mask_recall))
        ft.write('mask f1: {}\n'.format(mask_f1))
def parse_args():
    """parse_args"""
    parser = argparse.ArgumentParser(description='face attributes eval')
    parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
    parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
    args_opt = parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args_1 = parse_args()
    args_1.dst_h = config.dst_h
    args_1.dst_w = config.dst_w
    args_1.attri_num = config.attri_num
    args_1.classes = config.classes
    args_1.flat_dim = config.flat_dim
    args_1.fc_dim = config.fc_dim
    args_1.workers = config.workers
    main(args_1)

@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ckpt to air."""
import os
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.FaceAttribute.resnet18_softmax import get_resnet18
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
def main(args):
    network = get_resnet18(args)
    ckpt_path = args.model_path
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        print('-----------------------load model success-----------------------')
    else:
        print('-----------------------load model failed -----------------------')

    input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 112, 112)).astype(np.float32)
    tensor_input_data = Tensor(input_data)

    export(network, tensor_input_data, file_name=ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'),
           file_format='AIR')
    print('-----------------------export model success-----------------------')


def parse_args():
    """parse_args"""
    parser = argparse.ArgumentParser(description='Convert ckpt to air')
    parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    args_opt = parser.parse_args()
    return args_opt


if __name__ == "__main__":
    args_1 = parse_args()
    args_1.dst_h = config.dst_h
    args_1.dst_w = config.dst_w
    args_1.attri_num = config.attri_num
    args_1.classes = config.classes
    args_1.flat_dim = config.flat_dim
    args_1.fc_dim = config.fc_dim
    main(args_1)

@ -0,0 +1,81 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]"
echo " or: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
SCRIPT_NAME='train.py'
rm -rf ${current_exec_path}/device*
ulimit -c unlimited
MINDRECORD_FILE=$(get_real_path $1)
RANK_TABLE=$(get_real_path $2)
PRETRAINED_BACKBONE=''
if [ $# == 3 ]
then
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
fi
echo $MINDRECORD_FILE
echo $RANK_TABLE
echo $PRETRAINED_BACKBONE
export RANK_TABLE_FILE=$RANK_TABLE
export RANK_SIZE=8
echo 'start training'
for((i=0;i<=$RANK_SIZE-1;i++));
do
echo 'start rank '$i
mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i || exit
export RANK_ID=$i
dev=`expr $i + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--mindrecord_path=$MINDRECORD_FILE \
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
done
echo 'running'

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='eval.py'
ulimit -c unlimited
MINDRECORD_FILE=$(get_real_path $1)
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
echo $MINDRECORD_FILE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start evaluating'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--mindrecord_path=$MINDRECORD_FILE \
--model_path=$PRETRAINED_BACKBONE > eval.log 2>&1 &
echo 'running'

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='export.py'
ulimit -c unlimited
BATCH_SIZE=$1
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
echo $BATCH_SIZE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start converting'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--batch_size=$BATCH_SIZE \
--model_path=$PRETRAINED_BACKBONE > convert.log 2>&1 &
echo 'running'

@ -0,0 +1,77 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
echo " or: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
current_exec_path=$(pwd)
echo ${current_exec_path}
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
export RANK_SIZE=1
SCRIPT_NAME='train.py'
ulimit -c unlimited
MINDRECORD_FILE=$(get_real_path $1)
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=''
if [ $# == 3 ]
then
PRETRAINED_BACKBONE=$(get_real_path $3)
if [ ! -f $PRETRAINED_BACKBONE ]
then
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
exit 1
fi
fi
echo $MINDRECORD_FILE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
echo 'start training'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--world_size=1 \
--mindrecord_path=$MINDRECORD_FILE \
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
echo 'running'

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute cross entropy."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
class CrossEntropyWithIgnoreIndex(nn.Cell):
    '''Cross Entropy With Ignore Index Loss.'''
    def __init__(self):
        super(CrossEntropyWithIgnoreIndex, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, dtype=mstype.float32)
        self.off_value = Tensor(0.0, dtype=mstype.float32)
        self.cast = P.Cast()
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.greater = P.Greater()
        self.maximum = P.Maximum()
        self.fill = P.Fill()
        self.sum = P.ReduceSum(keep_dims=False)
        self.dtype = P.DType()
        self.relu = P.ReLU()
        self.reshape = P.Reshape()
        self.const_one = Tensor(np.ones([1]), dtype=mstype.float32)
        self.const_eps = Tensor(0.00001, dtype=mstype.float32)

    def construct(self, x, label):
        '''Construct function.'''
        # Per-sample mask: 1.0 for valid labels (>= 0), 0.0 for the ignore label (-1).
        mask = self.reshape(label, (F.shape(label)[0], 1))
        mask = self.cast(mask, mstype.float32)
        mask = mask + self.const_eps
        mask = self.relu(mask)/mask
        # Zero the logits of ignored samples.
        x = x * mask
        one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(x)[1], self.on_value, self.off_value)
        loss = self.ce(x, one_hot_label)
        # Average over the samples whose loss is positive (at least one, to avoid dividing by zero).
        positive = self.sum(self.cast(self.greater(loss, self.fill(self.dtype(loss), F.shape(loss), 0.0)),
                                      mstype.float32), 0)
        positive = self.maximum(positive, self.const_one)
        loss = self.sum(loss, 0) / positive
        return loss

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute network unit."""
import mindspore.nn as nn
from mindspore.nn import Dense
class Cut(nn.Cell):
    def construct(self, x):
        return x


def bn_with_initialize(out_channels):
    bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5)
    return bn


def fc_with_initialize(input_channels, out_channels):
    return Dense(input_channels, out_channels)


def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     pad_mode=pad_mode, group=groups, has_bias=False, dilation=dilation, padding=padding)


def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride, has_bias=False,
                     padding=padding)

@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute head."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn import Cell
from src.FaceAttribute.custom_net import fc_with_initialize
__all__ = ['get_attri_head']
class AttriHead(Cell):
    '''Attribute Head.'''
    def __init__(self, flat_dim, fc_dim, attri_num_list):
        super(AttriHead, self).__init__()
        self.fc1 = fc_with_initialize(flat_dim, fc_dim)
        self.fc1_relu = P.ReLU()
        self.fc1_bn = nn.BatchNorm1d(fc_dim, affine=False)
        self.attri_fc1 = fc_with_initialize(fc_dim, attri_num_list[0])
        self.attri_fc1_relu = P.ReLU()
        self.attri_bn1 = nn.BatchNorm1d(attri_num_list[0], affine=False)

        self.fc2 = fc_with_initialize(flat_dim, fc_dim)
        self.fc2_relu = P.ReLU()
        self.fc2_bn = nn.BatchNorm1d(fc_dim, affine=False)
        self.attri_fc2 = fc_with_initialize(fc_dim, attri_num_list[1])
        self.attri_fc2_relu = P.ReLU()
        self.attri_bn2 = nn.BatchNorm1d(attri_num_list[1], affine=False)

        self.fc3 = fc_with_initialize(flat_dim, fc_dim)
        self.fc3_relu = P.ReLU()
        self.fc3_bn = nn.BatchNorm1d(fc_dim, affine=False)
        self.attri_fc3 = fc_with_initialize(fc_dim, attri_num_list[2])
        self.attri_fc3_relu = P.ReLU()
        self.attri_bn3 = nn.BatchNorm1d(attri_num_list[2], affine=False)

    def construct(self, x):
        '''Construct function.'''
        output0 = self.fc1(x)
        output0 = self.fc1_relu(output0)
        output0 = self.fc1_bn(output0)
        output0 = self.attri_fc1(output0)
        output0 = self.attri_fc1_relu(output0)
        output0 = self.attri_bn1(output0)

        output1 = self.fc2(x)
        output1 = self.fc2_relu(output1)
        output1 = self.fc2_bn(output1)
        output1 = self.attri_fc2(output1)
        output1 = self.attri_fc2_relu(output1)
        output1 = self.attri_bn2(output1)

        output2 = self.fc3(x)
        output2 = self.fc3_relu(output2)
        output2 = self.fc3_bn(output2)
        output2 = self.attri_fc3(output2)
        output2 = self.attri_fc3_relu(output2)
        output2 = self.attri_bn3(output2)

        return output0, output1, output2


def get_attri_head(flat_dim, fc_dim, attri_num_list):
    attri_head = AttriHead(flat_dim, fc_dim, attri_num_list)
    return attri_head

@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute head."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn import Cell
from src.FaceAttribute.custom_net import fc_with_initialize
__all__ = ['get_attri_head']


class AttriHead(Cell):
    '''Attribute Head.'''
    def __init__(self, flat_dim, fc_dim, attri_num_list):
        super(AttriHead, self).__init__()
        self.fc1 = fc_with_initialize(flat_dim, fc_dim)
        self.fc1_relu = P.ReLU()
        self.fc1_bn = nn.BatchNorm1d(fc_dim, affine=False)
        self.attri_fc1 = fc_with_initialize(fc_dim, attri_num_list[0])
        self.attri_fc1_relu = P.ReLU()
        self.attri_bn1 = nn.BatchNorm1d(attri_num_list[0], affine=False)
        self.softmax1 = P.Softmax()

        self.fc2 = fc_with_initialize(flat_dim, fc_dim)
        self.fc2_relu = P.ReLU()
        self.fc2_bn = nn.BatchNorm1d(fc_dim, affine=False)
        self.attri_fc2 = fc_with_initialize(fc_dim, attri_num_list[1])
        self.attri_fc2_relu = P.ReLU()
        self.attri_bn2 = nn.BatchNorm1d(attri_num_list[1], affine=False)
        self.softmax2 = P.Softmax()

        self.fc3 = fc_with_initialize(flat_dim, fc_dim)
        self.fc3_relu = P.ReLU()
        self.fc3_bn = nn.BatchNorm1d(fc_dim, affine=False)
        self.attri_fc3 = fc_with_initialize(fc_dim, attri_num_list[2])
        self.attri_fc3_relu = P.ReLU()
        self.attri_bn3 = nn.BatchNorm1d(attri_num_list[2], affine=False)
        self.softmax3 = P.Softmax()

    def construct(self, x):
        '''Construct function.'''
        output0 = self.fc1(x)
        output0 = self.fc1_relu(output0)
        output0 = self.fc1_bn(output0)
        output0 = self.attri_fc1(output0)
        output0 = self.attri_fc1_relu(output0)
        output0 = self.attri_bn1(output0)
        output0 = self.softmax1(output0)

        output1 = self.fc2(x)
        output1 = self.fc2_relu(output1)
        output1 = self.fc2_bn(output1)
        output1 = self.attri_fc2(output1)
        output1 = self.attri_fc2_relu(output1)
        output1 = self.attri_bn2(output1)
        output1 = self.softmax2(output1)

        output2 = self.fc3(x)
        output2 = self.fc3_relu(output2)
        output2 = self.fc3_bn(output2)
        output2 = self.attri_fc3(output2)
        output2 = self.attri_fc3_relu(output2)
        output2 = self.attri_bn3(output2)
        output2 = self.softmax3(output2)

        return output0, output1, output2


def get_attri_head(flat_dim, fc_dim, attri_num_list):
    attri_head = AttriHead(flat_dim, fc_dim, attri_num_list)
    return attri_head
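
Each branch of this head ends in a softmax, so the three outputs can be read as per-attribute probability vectors. As a rough illustration only, the following pure-Python sketch shows what that last step computes; the logit values are made up, and the `softmax` helper is a stand-in for `P.Softmax`, not its implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of floats."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a two-class attribute branch.
probs = softmax([2.0, 0.0])
predicted_class = probs.index(max(probs))
```

The predicted class is simply the index of the largest probability, which is the same index as the largest logit.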


@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute loss."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.common import dtype as mstype
from src.FaceAttribute.cross_entropy import CrossEntropyWithIgnoreIndex
__all__ = ['get_loss']


class CriterionsFaceAttri(nn.Cell):
    '''Criterions Face Attribute.'''
    def __init__(self):
        super(CriterionsFaceAttri, self).__init__()

        # label
        self.gatherv2 = P.GatherV2()
        self.squeeze = P.Squeeze(axis=1)
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.mean = P.ReduceMean()

        self.label0_param = Tensor([0], dtype=mstype.int32)
        self.label1_param = Tensor([1], dtype=mstype.int32)
        self.label2_param = Tensor([2], dtype=mstype.int32)

        # loss
        self.ce_ignore_loss = CrossEntropyWithIgnoreIndex()
        self.printn = P.Print()

    def construct(self, x0, x1, x2, label):
        '''Construct function.'''
        # each sub attribute loss
        label0 = self.squeeze(self.gatherv2(label, self.label0_param, 1))
        loss0 = self.ce_ignore_loss(x0, label0)

        label1 = self.squeeze(self.gatherv2(label, self.label1_param, 1))
        loss1 = self.ce_ignore_loss(x1, label1)

        label2 = self.squeeze(self.gatherv2(label, self.label2_param, 1))
        loss2 = self.ce_ignore_loss(x2, label2)

        loss = loss0 + loss1 + loss2

        return loss


def get_loss():
    return CriterionsFaceAttri()
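
The criterion above picks column i of the label tensor for head i and sums the three per-attribute losses. The sketch below mirrors that structure in pure Python. `CrossEntropyWithIgnoreIndex` is not shown in this part of the commit, so the ignore value of `-1` and the averaging over valid samples are assumptions made for illustration, and the helper names are hypothetical.

```python
import math

IGNORE_LABEL = -1  # assumption: the real ignore value is defined in cross_entropy.py


def cross_entropy_ignore(logits_batch, labels, ignore_label=IGNORE_LABEL):
    """Mean cross-entropy over samples whose label is not ignore_label."""
    losses = []
    for logits, label in zip(logits_batch, labels):
        if label == ignore_label:
            continue  # unlabeled attribute: contributes nothing
        shift = max(logits)
        log_sum_exp = shift + math.log(sum(math.exp(v - shift) for v in logits))
        losses.append(log_sum_exp - logits[label])
    return sum(losses) / len(losses) if losses else 0.0


def total_loss(outputs, label_rows):
    """Sum the per-attribute losses; column i of label_rows belongs to head i."""
    loss = 0.0
    for i, logits_batch in enumerate(outputs):
        column = [row[i] for row in label_rows]
        loss += cross_entropy_ignore(logits_batch, column)
    return loss


# Two samples, three heads, all-zero logits (made-up values).
outputs = [
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],  # head 0: three classes
    [[0.0, 0.0], [0.0, 0.0]],            # head 1: two classes
    [[0.0, 0.0], [0.0, 0.0]],            # head 2: two classes
]
labels = [[2, 0, -1], [1, -1, -1]]
loss = total_loss(outputs, labels)
```

With uniform logits each valid sample costs log(number of classes), so head 0 contributes log 3, head 1 contributes log 2, and head 2, which has no valid label, contributes 0.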


@ -0,0 +1,145 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute resnet18 backbone."""
import mindspore.nn as nn
from mindspore.ops.operations import TensorAdd
from mindspore.ops import operations as P
from mindspore.nn import Cell
from src.FaceAttribute.custom_net import Cut, bn_with_initialize, conv1x1, conv3x3
from src.FaceAttribute.head_factory import get_attri_head
__all__ = ['get_resnet18']


class IRBlock(Cell):
    '''IRBlock.'''
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(IRBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride=stride)
        self.bn1 = bn_with_initialize(planes)
        self.relu1 = P.ReLU()
        self.conv2 = conv3x3(planes, planes, stride=1)
        self.bn2 = bn_with_initialize(planes)

        if downsample is None:
            self.downsample = Cut()
        else:
            self.downsample = downsample

        self.add = TensorAdd()
        self.cast = P.Cast()
        self.relu2 = P.ReLU()

    def construct(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        identity = self.downsample(x)
        out = self.add(out, identity)
        out = self.relu2(out)
        return out


class DownSample(Cell):
    '''Downsample shortcut: 1x1 convolution followed by batch norm.'''
    def __init__(self, inplanes, planes, expansion, stride):
        super(DownSample, self).__init__()
        self.conv1 = conv1x1(inplanes, planes * expansion, stride=stride, pad_mode="valid")
        self.bn1 = bn_with_initialize(planes * expansion)

    def construct(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        return out


class MakeLayer(Cell):
    '''Make layer.'''
    def __init__(self, block, inplanes, planes, blocks, stride=1):
        super(MakeLayer, self).__init__()

        self.inplanes = inplanes
        self.downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            self.downsample = DownSample(self.inplanes, planes, block.expansion, stride)

        self.layers = []
        self.layers.append(block(self.inplanes, planes, stride, self.downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            self.layers.append(block(self.inplanes, planes))
        self.layers = nn.CellList(self.layers)

    def construct(self, x):
        for block in self.layers:
            x = block(x)
        return x


class AttriResNet(Cell):
    '''Resnet for attribute.'''
    def __init__(self, block, layers, flat_dim, fc_dim, attri_num_list):
        super(AttriResNet, self).__init__()

        # resnet18
        self.inplanes = 32
        self.conv1 = conv3x3(3, self.inplanes, stride=1)
        self.bn1 = bn_with_initialize(self.inplanes)
        self.relu = P.ReLU()
        self.layer1 = MakeLayer(block, inplanes=32, planes=64, blocks=layers[0], stride=2)
        self.layer2 = MakeLayer(block, inplanes=64, planes=128, blocks=layers[1], stride=2)
        self.layer3 = MakeLayer(block, inplanes=128, planes=256, blocks=layers[2], stride=2)
        self.layer4 = MakeLayer(block, inplanes=256, planes=512, blocks=layers[3], stride=2)

        # avg global pooling
        self.mean = P.ReduceMean(keep_dims=True)
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.head = get_attri_head(flat_dim, fc_dim, attri_num_list)

    def construct(self, x):
        '''Construct function.'''
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # global average pooling over the spatial axes, then flatten to (batch, channels)
        x = self.mean(x, (2, 3))
        b, c, _, _ = self.shape(x)
        x = self.reshape(x, (b, c))
        return self.head(x)


def get_resnet18(args):
    '''Build resnet18 for attribute.'''
    flat_dim = args.flat_dim
    fc_dim = args.fc_dim
    str_classes = args.classes.strip().split(',')
    if args.attri_num != len(str_classes):
        print('args warning: attri_num != classes num')
        return None
    attri_num_list = [int(num) for num in str_classes]

    attri_resnet18 = AttriResNet(IRBlock, (2, 2, 2, 2), flat_dim, fc_dim, attri_num_list)

    return attri_resnet18
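
The backbone keeps the input resolution in `conv1` and halves it in each of the four stages, so a 112x112 input (the `dst_h`/`dst_w` values in the config) should reach the global pooling step as a 7x7 map with 512 channels. The sketch below traces that arithmetic. It assumes every strided convolution produces `ceil(input / stride)` outputs; the actual padding is set in `custom_net.py`, which is not shown here, and `trace_shapes` is a hypothetical helper.

```python
def trace_shapes(height, width):
    """Return (name, channels, height, width) after each stage.

    Assumes output size = ceil(input size / stride) for every stage.
    """
    stages = [("conv1", 32, 1), ("layer1", 64, 2), ("layer2", 128, 2),
              ("layer3", 256, 2), ("layer4", 512, 2)]
    shapes = []
    for name, channels, stride in stages:
        height = -(-height // stride)  # ceiling division
        width = -(-width // stride)
        shapes.append((name, channels, height, width))
    return shapes


shapes = trace_shapes(112, 112)
flat_dim = shapes[-1][1]  # global average pooling keeps only the channel axis
```

Under these assumptions the flattened feature has 512 entries, which matches the `flat_dim` value of 512 in the config.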


@ -0,0 +1,145 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute resnet18 backbone."""
import mindspore.nn as nn
from mindspore.ops.operations import TensorAdd
from mindspore.ops import operations as P
from mindspore.nn import Cell
from src.FaceAttribute.custom_net import Cut, bn_with_initialize, conv1x1, conv3x3
from src.FaceAttribute.head_factory_softmax import get_attri_head
__all__ = ['get_resnet18']


class IRBlock(Cell):
    '''IRBlock.'''
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(IRBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride=stride)
        self.bn1 = bn_with_initialize(planes)
        self.relu1 = P.ReLU()
        self.conv2 = conv3x3(planes, planes, stride=1)
        self.bn2 = bn_with_initialize(planes)

        if downsample is None:
            self.downsample = Cut()
        else:
            self.downsample = downsample

        self.add = TensorAdd()
        self.cast = P.Cast()
        self.relu2 = P.ReLU()

    def construct(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        identity = self.downsample(x)
        out = self.add(out, identity)
        out = self.relu2(out)
        return out


class DownSample(Cell):
    '''Downsample shortcut: 1x1 convolution followed by batch norm.'''
    def __init__(self, inplanes, planes, expansion, stride):
        super(DownSample, self).__init__()
        self.conv1 = conv1x1(inplanes, planes * expansion, stride=stride, pad_mode="valid")
        self.bn1 = bn_with_initialize(planes * expansion)

    def construct(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        return out


class MakeLayer(Cell):
    '''Make layer.'''
    def __init__(self, block, inplanes, planes, blocks, stride=1):
        super(MakeLayer, self).__init__()

        self.inplanes = inplanes
        self.downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            self.downsample = DownSample(self.inplanes, planes, block.expansion, stride)

        self.layers = []
        self.layers.append(block(self.inplanes, planes, stride, self.downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            self.layers.append(block(self.inplanes, planes))
        self.layers = nn.CellList(self.layers)

    def construct(self, x):
        for block in self.layers:
            x = block(x)
        return x


class AttriResNet(Cell):
    '''Resnet for attribute.'''
    def __init__(self, block, layers, flat_dim, fc_dim, attri_num_list):
        super(AttriResNet, self).__init__()

        # resnet18
        self.inplanes = 32
        self.conv1 = conv3x3(3, self.inplanes, stride=1)
        self.bn1 = bn_with_initialize(self.inplanes)
        self.relu = P.ReLU()
        self.layer1 = MakeLayer(block, inplanes=32, planes=64, blocks=layers[0], stride=2)
        self.layer2 = MakeLayer(block, inplanes=64, planes=128, blocks=layers[1], stride=2)
        self.layer3 = MakeLayer(block, inplanes=128, planes=256, blocks=layers[2], stride=2)
        self.layer4 = MakeLayer(block, inplanes=256, planes=512, blocks=layers[3], stride=2)

        # avg global pooling
        self.mean = P.ReduceMean(keep_dims=True)
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.head = get_attri_head(flat_dim, fc_dim, attri_num_list)

    def construct(self, x):
        '''Construct function.'''
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # global average pooling over the spatial axes, then flatten to (batch, channels)
        x = self.mean(x, (2, 3))
        b, c, _, _ = self.shape(x)
        x = self.reshape(x, (b, c))
        return self.head(x)


def get_resnet18(args):
    '''Build resnet18 for attribute.'''
    flat_dim = args.flat_dim
    fc_dim = args.fc_dim
    str_classes = args.classes.strip().split(',')
    if args.attri_num != len(str_classes):
        print('args warning: attri_num != classes num')
        return None
    attri_num_list = [int(num) for num in str_classes]

    attri_resnet18 = AttriResNet(IRBlock, (2, 2, 2, 2), flat_dim, fc_dim, attri_num_list)

    return attri_resnet18


@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Network config setting, will be used in train.py and eval.py"""
from easydict import EasyDict as ed

config = ed({
    'per_batch_size': 128,
    'dst_h': 112,
    'dst_w': 112,
    'workers': 8,
    'attri_num': 3,
    'classes': '9,2,2',
    'backbone': 'resnet18',
    'loss_scale': 1024,
    'flat_dim': 512,
    'fc_dim': 256,
    'lr': 0.009,
    'lr_scale': 1,
    'lr_epochs': [20, 30, 50],
    'weight_decay': 0.0005,
    'momentum': 0.9,
    'max_epoch': 70,
    'warmup_epochs': 0,
    'log_interval': 10,
    'ckpt_path': '../../output',

    # data_to_mindrecord parameter
    'eval_dataset_txt_file': 'Your_label_txt_file',
    'eval_mindrecord_file_name': 'Your_output_path/data_test.mindrecord',
    'train_dataset_txt_file': 'Your_label_txt_file',
    'train_mindrecord_file_name': 'Your_output_path/data_train.mindrecord',
    'train_append_dataset_txt_file': 'Your_label_txt_file',
    'train_append_mindrecord_file_name': 'Your_previous_output_path/data_train.mindrecord0'
})
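
The `classes` entry is a comma-separated string with one class count per attribute, and `attri_num` has to match its length. A minimal sketch of that check follows. `parse_classes` is a hypothetical helper, and raising `ValueError` on a mismatch is a choice made for this example; `get_resnet18` in this commit prints a warning and returns `None` instead.

```python
def parse_classes(classes, attri_num):
    """Split a string such as '9,2,2' into a list of ints and check its length."""
    sizes = [int(item) for item in classes.strip().split(',')]
    if len(sizes) != attri_num:
        raise ValueError(
            'attri_num is {} but classes has {} entries'.format(attri_num, len(sizes)))
    return sizes


# Values taken from the config above.
attri_num_list = parse_classes('9,2,2', 3)
```

Each entry of the resulting list becomes the output width of one head branch.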


@ -0,0 +1,66 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert dataset to mindrecord for evaluating Face attribute."""
import numpy as np
from mindspore.mindrecord import FileWriter
from config import config
dataset_txt_file = config.eval_dataset_txt_file
mindrecord_file_name = config.eval_mindrecord_file_name
mindrecord_num = 8


def convert_data_to_mindrecord():
    '''Convert data to mindrecord.'''
    writer = FileWriter(mindrecord_file_name, mindrecord_num)
    attri_json = {
        "image": {"type": "bytes"},
        "label": {"type": "int32", "shape": [-1]}
    }

    print('Loading eval data...')
    total_data = []
    with open(dataset_txt_file, 'r') as ft:
        lines = ft.readlines()
        for line in lines:
            # each line: "<image path> <label> <label> ..."
            sline = line.strip().split(" ")
            image_file = sline[0]
            labels = []
            for item in sline[1:]:
                labels.append(int(item))

            with open(image_file, 'rb') as f:
                img = f.read()

            data = {
                "image": img,
                "label": np.array(labels, dtype='int32')
            }

            total_data.append(data)

    print('Writing eval data to mindrecord...')
    writer.add_schema(attri_json, "attri_json")
    # total_data is always a list, so check for emptiness rather than None
    if not total_data:
        raise ValueError("No data to write to mindrecord.")
    writer.write_raw_data(total_data)
    writer.commit()


convert_data_to_mindrecord()
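
The conversion script expects one sample per line in the label text file: an image path followed by space-separated integer labels. The parsing step can be sketched on its own as below; `parse_label_line` is a hypothetical helper, and the path and label values are made up.

```python
def parse_label_line(line):
    """Split 'path label label ...' into (image_path, [int labels])."""
    fields = line.strip().split(" ")
    return fields[0], [int(item) for item in fields[1:]]


# Hypothetical line: the path and label values are made up for illustration.
image_file, labels = parse_label_line("images/0001.jpg 4 1 -1\n")
```

Because the line is split on single spaces, an image path that itself contains a space would be cut short and the remainder would fail the `int` conversion.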


@ -0,0 +1,66 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert dataset to mindrecord for training Face attribute."""
import numpy as np
from mindspore.mindrecord import FileWriter
from config import config
dataset_txt_file = config.train_dataset_txt_file
mindrecord_file_name = config.train_mindrecord_file_name
mindrecord_num = 8


def convert_data_to_mindrecord():
    '''Convert data to mindrecord.'''
    writer = FileWriter(mindrecord_file_name, mindrecord_num)
    attri_json = {
        "image": {"type": "bytes"},
        "label": {"type": "int32", "shape": [-1]}
    }

    print('Loading train data...')
    total_data = []
    with open(dataset_txt_file, 'r') as ft:
        lines = ft.readlines()
        for line in lines:
            # each line: "<image path> <label> <label> ..."
            sline = line.strip().split(" ")
            image_file = sline[0]
            labels = []
            for item in sline[1:]:
                labels.append(int(item))

            with open(image_file, 'rb') as f:
                img = f.read()

            data = {
                "image": img,
                "label": np.array(labels, dtype='int32')
            }

            total_data.append(data)

    print('Writing train data to mindrecord...')
    writer.add_schema(attri_json, "attri_json")
    # total_data is always a list, so check for emptiness rather than None
    if not total_data:
        raise ValueError("No data to write to mindrecord.")
    writer.write_raw_data(total_data)
    writer.commit()


convert_data_to_mindrecord()


@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Add dataset to an existing mindrecord for training Face attribute."""
import numpy as np

from mindspore.mindrecord import FileWriter

from config import config

dataset_txt_file = config.train_append_dataset_txt_file

mindrecord_file_name = config.train_append_mindrecord_file_name
mindrecord_num = 8


def convert_data_to_mindrecord():
    '''Convert data to mindrecord.'''
    print('Loading mindrecord...')
    writer = FileWriter.open_for_append(mindrecord_file_name)

    print('Loading train data...')
    total_data = []
    with open(dataset_txt_file, 'r') as ft:
        lines = ft.readlines()
        for line in lines:
            # each line: "<image path> <label> <label> ..."
            sline = line.strip().split(" ")
            image_file = sline[0]
            labels = []
            for item in sline[1:]:
                labels.append(int(item))

            with open(image_file, 'rb') as f:
                img = f.read()

            data = {
                "image": img,
                "label": np.array(labels, dtype='int32')
            }

            total_data.append(data)

    print('Writing train data to mindrecord...')
    # total_data is always a list, so check for emptiness rather than None
    if not total_data:
        raise ValueError("No data to write to mindrecord.")
    writer.write_raw_data(total_data)
    writer.commit()


convert_data_to_mindrecord()


@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute dataset for eval"""
import mindspore.dataset as de
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.transforms.py_transforms as F2
__all__ = ['data_generator_eval']


def data_generator_eval(args):
    '''Build eval dataloader.'''
    mindrecord_path = args.mindrecord_path
    dst_w = args.dst_w
    dst_h = args.dst_h
    batch_size = 1
    attri_num = args.attri_num
    # Resize takes (height, width)
    transform_img = F2.Compose([F.Decode(),
                                F.Resize((dst_h, dst_w)),
                                F.ToTensor(),
                                F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    de_dataset = de.MindDataset(mindrecord_path + "0", columns_list=["image", "label"])
    de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=args.workers,
                                python_multiprocessing=True)
    de_dataset = de_dataset.batch(batch_size)

    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
    # batch_size is 1, so the number of steps equals the number of images
    steps_per_epoch = de_dataset.get_dataset_size()
    print("image number:{0}".format(steps_per_epoch))
    num_classes = attri_num

    return de_dataloader, steps_per_epoch, num_classes
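
The transform pipeline scales pixels to [0, 1] with `ToTensor` and then applies `Normalize` with mean 0.5 and standard deviation 0.5 per channel, which maps the result to [-1, 1]. The per-pixel arithmetic can be sketched as follows; `normalize_pixel` is a hypothetical helper for illustration and does not touch MindSpore.

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Map an 8-bit pixel value to the range produced by ToTensor + Normalize."""
    scaled = value / 255.0          # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std    # Normalize: (x - mean) / std


low = normalize_pixel(0)
mid = normalize_pixel(127.5)
high = normalize_pixel(255)
```

Black maps to -1, mid-gray to 0, and white to 1, so the network sees inputs centered on zero.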


@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute dataset for train"""
import mindspore.dataset as de
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.transforms.py_transforms as F2
__all__ = ['data_generator']


def data_generator(args):
    '''Build train dataloader.'''
    mindrecord_path = args.mindrecord_path
    dst_w = args.dst_w
    dst_h = args.dst_h
    batch_size = args.per_batch_size
    attri_num = args.attri_num
    max_epoch = args.max_epoch
    # Resize takes (height, width)
    transform_img = F2.Compose([F.Decode(),
                                F.Resize((dst_h, dst_w)),
                                F.RandomHorizontalFlip(prob=0.5),
                                F.ToTensor(),
                                F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    de_dataset = de.MindDataset(mindrecord_path + "0", columns_list=["image", "label"], num_shards=args.world_size,
                                shard_id=args.local_rank)

    de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=args.workers,
                                python_multiprocessing=True)
    de_dataset = de_dataset.batch(batch_size, drop_remainder=True)
    # read the size before repeat() so it counts the steps of a single epoch
    steps_per_epoch = de_dataset.get_dataset_size()
    de_dataset = de_dataset.repeat(max_epoch)
    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)

    num_classes = attri_num

    return de_dataloader, steps_per_epoch, num_classes
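
Since the dataset size is read after batching with `drop_remainder=True` and before `repeat(max_epoch)`, `steps_per_epoch` counts the full batches one device sees in a single epoch, and the iterator yields that many steps times `max_epoch` in total. A small sketch of the arithmetic, with a made-up shard size and a hypothetical helper name (the helper is named `count_steps` to avoid clashing with the variable above):

```python
def count_steps(samples_in_shard, batch_size, drop_remainder=True):
    """Number of batches one device sees per epoch."""
    full, rest = divmod(samples_in_shard, batch_size)
    if drop_remainder or rest == 0:
        return full
    return full + 1


# Hypothetical shard size; batch size 128 and 70 epochs follow the config.
steps = count_steps(samples_in_shard=10000, batch_size=128)
total_steps = steps * 70
```

With these numbers the last 16 samples of the shard are dropped in every epoch.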


@ -0,0 +1,105 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom logger."""
import logging
import os
import sys
from datetime import datetime
logger_name_1 = 'face_attributes'


class LOGGER(logging.Logger):
    '''Logger.'''
    def __init__(self, logger_name):
        super(LOGGER, self).__init__(logger_name)
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        console.setFormatter(formatter)
        self.addHandler(console)
        self.local_rank = 0

    def setup_logging_file(self, log_dir, local_rank=0):
        '''Setup logging file.'''
        self.local_rank = local_rank
        if self.local_rank == 0:
            if not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
            self.log_fn = os.path.join(log_dir, log_name)
            fh = logging.FileHandler(self.log_fn)
            fh.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            fh.setFormatter(formatter)
            self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, rank):
    logger = LOGGER(logger_name_1)
    logger.setup_logging_file(path, rank)
    return logger


class AverageMeter():
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.tb_writer = tb_writer
        self.cur_step = 1

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
        self.cur_step += 1

    def __str__(self):
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
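
`AverageMeter.update(val, n)` weights each value by `n`, so the reported average is a sample-weighted mean rather than a plain mean over calls. The standalone sketch below replays that arithmetic; `running_average` is a hypothetical helper and the loss values are made up.

```python
def running_average(updates):
    """Replay (value, n) updates and return the weighted mean."""
    total, count = 0.0, 0
    for value, n in updates:
        total += value * n
        count += n
    return total / count


# Hypothetical loss values: one batch of 128 samples and one of 64.
avg = running_average([(0.9, 128), (0.6, 64)])
```

Here the larger batch pulls the average toward 0.9, giving 0.8 instead of the unweighted 0.75.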


@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute learning rate scheduler."""
from collections import Counter


def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    learning_rate = float(init_lr) + lr_inc * current_step
    return learning_rate


def warmup_step(args, gamma=0.1):
    '''Warmup step.'''
    base_lr = args.lr
    warmup_init_lr = 0
    total_steps = int(args.max_epoch * args.steps_per_epoch)
    warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
    milestones = args.lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone * args.steps_per_epoch
        milestones_steps.append(milestones_step)

    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_learning_rate(i, warmup_steps, base_lr, warmup_init_lr)
        else:
            # multiply by gamma once for every milestone that falls on this step
            lr = lr * gamma**milestones_steps_counter[i]
        yield lr
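
With the values from the config (`lr` 0.009, `lr_epochs` [20, 30, 50], `warmup_epochs` 0) and the default `gamma` of 0.1, the schedule is a plain step decay: the rate is divided by ten at the first step of epochs 20, 30 and 50. The following simplified sketch leaves out the warmup branch and uses a made-up `steps_per_epoch` of 2 to keep the output short; `step_schedule` is a hypothetical helper, not a function from this commit.

```python
from collections import Counter


def step_schedule(base_lr, lr_epochs, steps_per_epoch, max_epoch, gamma=0.1):
    """Per-step learning rates for step decay without warmup."""
    milestones = Counter(epoch * steps_per_epoch for epoch in lr_epochs)
    lr = base_lr
    rates = []
    for step in range(max_epoch * steps_per_epoch):
        lr = lr * gamma ** milestones[step]  # a missing key counts as 0, so gamma ** 0 == 1
        rates.append(lr)
    return rates


# steps_per_epoch=2 is a made-up value to keep the list short.
rates = step_schedule(0.009, [20, 30, 50], steps_per_epoch=2, max_epoch=70)
lr_by_epoch = rates[::2]  # rate at the first step of each epoch
```

Using a `Counter` means a milestone listed twice would apply the decay twice at that step, which is also how `warmup_step` behaves.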


@ -0,0 +1,232 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute train."""
import os
import time
import datetime
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank
from mindspore.nn import TrainOneStepCell
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from src.FaceAttribute.resnet18 import get_resnet18
from src.FaceAttribute.loss_factory import get_loss
from src.dataset_train import data_generator
from src.lrsche_factory import warmup_step
from src.logging import get_logger, AverageMeter
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
class BuildTrainNetwork(nn.Cell):
'''Build train network.'''
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
self.print = P.Print()
def construct(self, input_data, label):
logit0, logit1, logit2 = self.network(input_data)
loss = self.criterion(logit0, logit1, logit2, label)
return loss
def parse_args():
'''Argument for Face Attributes.'''
parser = argparse.ArgumentParser('Face Attributes')
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
args, _ = parser.parse_known_args()
return args


def train():
    '''train function.'''
    # parse arguments
    args = parse_args()

    # init distributed
    if args.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()

    # copy hyperparameters from the config file into args
    args.per_batch_size = config.per_batch_size
    args.dst_h = config.dst_h
    args.dst_w = config.dst_w
    args.workers = config.workers
    args.attri_num = config.attri_num
    args.classes = config.classes
    args.backbone = config.backbone
    args.loss_scale = config.loss_scale
    args.flat_dim = config.flat_dim
    args.fc_dim = config.fc_dim
    args.lr = config.lr
    args.lr_scale = config.lr_scale
    args.lr_epochs = config.lr_epochs
    args.weight_decay = config.weight_decay
    args.momentum = config.momentum
    args.max_epoch = config.max_epoch
    args.warmup_epochs = config.warmup_epochs
    args.log_interval = config.log_interval
    args.ckpt_path = config.ckpt_path

    # single device: fixed batch size of 256; distributed: scale the learning rate by 4
    if args.world_size == 1:
        args.per_batch_size = 256
    else:
        args.lr = args.lr * 4.

    if args.world_size != 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)

    # model and log save path
    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.local_rank)
    loss_meter = AverageMeter('loss')

    # dataloader
    args.logger.info('start create dataloader')
    de_dataloader, steps_per_epoch, num_classes = data_generator(args)
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('end create dataloader')
    args.logger.save_args(args)

    # backbone and loss
    args.logger.important_info('start create network')
    create_network_start = time.time()
    network = get_resnet18(args)
    criterion = get_loss()

    # load pretrained model: skip optimizer state ('moments.') and strip the 'network.' prefix
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))

    # optimizer and lr scheduler
    lr = warmup_step(args, gamma=0.1)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    train_net = BuildTrainNetwork(network, criterion)

    # mixed precision training
    criterion.add_flags_recursive(fp32=True)

    # package training process
    train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
    context.reset_auto_parallel_context()

    # checkpoint (saved by rank 0 only)
    if args.local_rank == 0:
        ckpt_max_num = args.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 0
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    i = 0
    for _, (data, gt_classes) in enumerate(de_dataloader):
        data_tensor = Tensor(data, dtype=mstype.float32)
        gt_tensor = Tensor(gt_classes, dtype=mstype.int32)
        loss = train_net(data_tensor, gt_tensor)
        loss_meter.update(loss.asnumpy()[0])

        # save ckpt
        if args.local_rank == 0:
            cb_params.cur_step_num = i + 1
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            cb_params.cur_epoch_num += 1

        # save log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))

        # per-interval throughput across all devices
        if i % args.log_interval == 0 and args.local_rank == 0:
            time_used = time.time() - t_end
            epoch = i // args.steps_per_epoch
            fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        # per-epoch throughput across all devices
        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = i // args.steps_per_epoch
            fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
            args.logger.info('=================================================')
            args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            args.logger.info('=================================================')
            t_epoch = time.time()

        i += 1

    args.logger.info('--------- trains out ---------')


if __name__ == "__main__":
    train()