!21077 facerecognition on gpu

Merge pull request !21077 from 郑彬/facerecognition_gpu
This commit is contained in:
i-robot 2021-08-17 07:39:15 +00:00 committed by Gitee
commit f919a6f78c
10 changed files with 341 additions and 73 deletions

View File

@ -13,7 +13,7 @@
# [Face Recognition Description](#contents)
This is a face recognition network based on Resnet, with support for training and evaluation on Ascend910.
This is a face recognition network based on Resnet, with support for training and evaluation on Ascend910, CPU or GPU.
ResNet (residual neural network) was proposed by Kaiming He and other four Chinese of Microsoft Research Institute. Through the use of ResNet unit, it successfully trained 152 layers of neural network, and won the championship in ilsvrc2015. The error rate on top 5 was 3.57%, and the parameter quantity was lower than vggnet, so the effect was very outstanding. Traditional convolution network or full connection network will have more or less information loss. At the same time, it will lead to the disappearance or explosion of gradient, which leads to the failure of deep network training. ResNet solves this problem to a certain extent. By passing the input information to the output, the integrity of the information is protected. The whole network only needs to learn the part of the difference between input and output, which simplifies the learning objectives and difficulties.The structure of ResNet can accelerate the training of neural network very quickly, and the accuracy of the model is also greatly improved. At the same time, ResNet is very popular, even can be directly used in the concept net network.
@ -55,8 +55,8 @@ The directory structure is as follows:
# [Environment Requirements](#contents)
- HardwareAscend, CPU
- Prepare hardware environment with Ascend processor. It also supports the use of CPU processor to prepare the
- HardwareAscend, CPU, GPU
- Prepare hardware environment with Ascend processor. It also supports the use of CPU or GPU processor to prepare the
hardware environment.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
@ -71,16 +71,20 @@ The directory structure is as follows:
The entire code structure is as following:
```python
└─ face_recognition
└─ FaceRecognition
├── ascend310_infer
├── README.md // descriptions about face_recognition
├── scripts
│ ├── run_distribute_train_base.sh // shell script for distributed training on Ascend
│ ├── run_distribute_train_beta.sh // shell script for distributed training on Ascend
│ ├── run_distribute_train_for_gpu.sh // shell script for distributed training on GPU
│ ├── run_eval.sh // shell script for evaluation on Ascend
│ ├── run_eval_cpu.sh // shell script for evaluation on CPU
│ ├── run_eval_gpu.sh // shell script for evaluation on gpu
│ ├── run_export.sh // shell script for exporting air model
│ ├── run_standalone_train_base.sh // shell script for standalone training on Ascend
│ ├── run_standalone_train_beta.sh // shell script for standalone training on Ascend
│ ├── run_standalone_train_for_gpu.sh // shell script for standalone training on GPU
│ ├── run_train_base_cpu.sh // shell script for training on CPU
│ ├── run_train_btae_cpu.sh // shell script for training on CPU
├── src
@ -97,7 +101,7 @@ The entire code structure is as following:
│ ├── lrsche_factory.py // learning rate schedule
│ ├── me_init.py // network parameter init method
│ ├── metric_factory.py // metric fc layer
── utils
── model_utils
│ ├── __init__.py // init file
│ ├── config.py // parameter analysis
│ ├── device_adapter.py // device adapter
@ -124,58 +128,98 @@ The entire code structure is as following:
```bash
cd ./scripts
sh run_standalone_train_base.sh [USE_DEVICE_ID]
bash run_standalone_train_base.sh [USE_DEVICE_ID]
```
for example:
```bash
cd ./scripts
sh run_standalone_train_base.sh 0
bash run_standalone_train_base.sh 0
```
- beta model
```bash
cd ./scripts
sh run_standalone_train_beta.sh [USE_DEVICE_ID]
bash run_standalone_train_beta.sh [USE_DEVICE_ID]
```
for example:
```bash
cd ./scripts
sh run_standalone_train_beta.sh 0
bash run_standalone_train_beta.sh 0
```
- Distribute mode (recommended)
- Stand alone mode(GPU)
- base/beta model
```bash
cd ./scripts
bash run_standalone_train_for_gpu.sh [base/beta] [DEVICE_ID](optional)
```
for example:
```bash
#base
cd ./scripts
bash run_standalone_train_for_gpu.sh base 3
#beta
cd ./scripts
bash run_standalone_train_for_gpu.sh beta 3
```
- Distribute mode (Ascend, recommended)
- base model
```bash
cd ./scripts
sh run_distribute_train_base.sh [RANK_TABLE]
bash run_distribute_train_base.sh [RANK_TABLE]
```
for example:
```bash
cd ./scripts
sh run_distribute_train_base.sh ./rank_table_8p.json
bash run_distribute_train_base.sh ./rank_table_8p.json
```
- beta model
```bash
cd ./scripts
sh run_distribute_train_beta.sh [RANK_TABLE]
bash run_distribute_train_beta.sh [RANK_TABLE]
```
for example:
```bash
cd ./scripts
sh run_distribute_train_beta.sh ./rank_table_8p.json
bash run_distribute_train_beta.sh ./rank_table_8p.json
```
- Distribute mode (GPU)
- base model
```bash
cd ./scripts
bash run_distribute_train_for_gpu.sh [RANK_SIZE] [base/beta] [CONFIG_PATH](optional)
```
for example:
```bash
#base
cd ./scripts
bash run_distribute_train_for_gpu.sh 8 base
#beta
cd ./scripts
bash run_distribute_train_for_gpu.sh 8 beta
```
- Stand alone mode(CPU)
@ -184,28 +228,28 @@ The entire code structure is as following:
```bash
cd ./scripts
sh run_train_base_cpu.sh
bash run_train_base_cpu.sh
```
for example:
```bash
cd ./scripts
sh run_train_base_cpu.sh
bash run_train_base_cpu.sh
```
- beta model
```bash
cd ./scripts
sh run_train_beta_cpu.sh
bash run_train_beta_cpu.sh
```
for example:
```bash
cd ./scripts
sh run_train_beta_cpu.sh
bash run_train_beta_cpu.sh
```
- ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
@ -352,34 +396,34 @@ You will get the result as following in "./scripts/acc.log" if 'dis_dataset' ran
### Training Performance
| Parameters | Face Recognition |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
| uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 4.7 million images |
| Training Parameters | epoch=100, batch_size=192, momentum=0.9 |
| Optimizer | Momentum |
| Loss Function | Cross Entropy |
| outputs | probability |
| Speed | 1pc: 350-600 fps; 8pcs: 2500-4500 fps |
| Total time | 1pc: NA hours; 8pcs: 10 hours |
| Checkpoint for Fine tuning | 584M (.ckpt file) |
| Parameters | Face Recognition | Face Recognition |
| -------------------------- | ----------------------------------------------------------- | ------------------ |
| Model Version | V1 | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G |
| uploaded Date | 09/30/2020 (month/day/year) | 29/07/2021 (month/day/year) |
| MindSpore Version | 1.0.0 | 1.3.0 |
| Dataset | 4.7 million images | 4.7 million images |
| Training Parameters | epoch=100, batch_size=192, momentum=0.9 | epoch=18(base:9, beta:9), batch_size=192, momentum=0.9 |
| Optimizer | Momentum | Momentum |
| Loss Function | Cross Entropy | Cross Entropy |
| outputs | probability | probability |
| Speed | 1pc: 350-600 fps; 8pcs: 2500-4500 fps | base: 1pc: 310-360 fps, 8pcs: 2000-2500 fps; beta: 1pc: 420-470 fps, 8pcs: 3000-3500 fps; |
| Total time | 1pc: NA hours; 8pcs: 10 hours | 1pc: NA hours; 8pcs: 5.5(base) + 3.7(beta) hours |
| Checkpoint for Fine tuning | 584M (.ckpt file) | 768M (.ckpt file, base), 582M (.ckpt file, beta) |
### Evaluation Performance
| Parameters |Face Recognition For Tracking|
| ------------------- | --------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; OS Euler2.8 |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | 1.1 million images |
| batch_size | 512 |
| outputs | ACC |
| ACC | 0.9 |
| Model for inference | 584M (.ckpt file) |
| Parameters | Face Recognition | Face Recognition |
| ------------------- | --------------------------- | --------------------------- |
| Model Version | V1 | V1 |
| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G |
| Uploaded Date | 09/30/2020 (month/day/year) | 29/07/2021 (month/day/year) |
| MindSpore Version | 1.0.0 | 1.3.0 |
| Dataset | 1.1 million images | 1.1 million images |
| batch_size | 512 | 512 |
| outputs | ACC | ACC |
| ACC | 0.9 | 0.9 |
| Model for inference | 584M (.ckpt file) | 582M (.ckpt file) |
# [ModelZoo Homepage](#contents)

View File

@ -20,6 +20,7 @@ from pprint import pformat
import numpy as np
import cv2
from mindspore.common import dtype as mstype
import mindspore.dataset.transforms.py_transforms as transforms
import mindspore.dataset.vision.py_transforms as vision
import mindspore.dataset as de
@ -127,7 +128,6 @@ def get_model(args):
net = get_backbone(args)
if args.fp16:
net.add_flags_recursive(fp16=True)
if args.weight.endswith('.ckpt'):
param_dict = load_checkpoint(args.weight)
param_dict_new = {}
@ -143,6 +143,8 @@ def get_model(args):
else:
args.logger.info('ERROR, not support file:{}, please check weight in config.py'.format(args.weight))
return 0
if args.device_target == 'GPU':
net.to_float(mstype.float32)
net.set_train(False)
return net

View File

@ -23,7 +23,7 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in
from src.backbone.resnet import get_backbone
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id
def modelarts_pre_process():
'''modelarts pre process function.'''
@ -41,8 +41,8 @@ def run_export():
config.backbone = config.export_backbone
config.use_drop = config.export_use_drop
devid = 0
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False,
device_id=get_device_id())
network = get_backbone(config)

View File

@ -0,0 +1,72 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]; then
echo "Usage: bash run_distribute_train_for_gpu.sh [RANK_SIZE] [base/beta] [CUDA_VISIBLE_DEVICES(0,1,2,3,4,5,6,7)](optional)"
exit 1
fi
expr $1 + 0 &>/dev/null
if [ $? != 0 ]
then
echo "error:RANK_SIZE=$1 is not a integer"
exit 1
fi
if [ $2 = "base" ]; then
CONFIG_PATH='./base_config.yaml'
elif [ $2 = "beta" ]; then
CONFIG_PATH='./beta_config.yaml'
else
echo "error: the train_stage is neither base nor beta"
exit 1
fi
if [ $# != 3 ]; then
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
else
export CUDA_VISIBLE_DEVICES=$3
fi
RANK_SIZE=$1
TRAIN_STAGE=$2
TRAIN_OUTPUT=./distribute_train_for_gpu_$TRAIN_STAGE
EXECUTE_PATH=$(pwd)
echo *******************EXECUTE_PATH=$EXECUTE_PATH
echo TRAIN_OUTPUT=$TRAIN_OUTPUT
echo CONFIG_PATH=$CONFIG_PATH
echo RANK_SIZE=$RANK_SIZE
echo TRAIN_STAGE=$TRAIN_STAGE
echo '*********************************************'
if [ -d $TRAIN_OUTPUT ]; then
rm -rf $TRAIN_OUTPUT
fi
mkdir $TRAIN_OUTPUT
cp ../train.py $TRAIN_OUTPUT
cp ../*.yaml $TRAIN_OUTPUT
cp -r ../model_utils $TRAIN_OUTPUT
cp -r ../src $TRAIN_OUTPUT
cd $TRAIN_OUTPUT || exit
env > env.log
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py \
--config_path=$CONFIG_PATH \
--device_target=GPU \
--train_stage=$TRAIN_STAGE \
--is_distributed=1 &> train_distribute_for_gpu.log &
cd ..

View File

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -gt 1 ]; then
echo "Usage: run_eval_gpu.sh [USE_DEVICE_ID](optional)"
exit 1
fi
dirname_path=$(dirname "$(pwd)")
echo ${dirname_path}
export PYTHONPATH=${dirname_path}:$PYTHONPATH
if [ $# -eq 1 ]; then
USE_DEVICE_ID=$1
else
USE_DEVICE_ID=0
fi
echo 'start device '$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=0
export CUDA_VISIBLE_DEVICES=$dev
EXECUTE_PATH=$(pwd)
echo *******************EXECUTE_PATH= $EXECUTE_PATH
if [ -d "${EXECUTE_PATH}/log_inference" ]; then
echo "[INFO] Delete old log_inference log files"
rm -rf ${EXECUTE_PATH}/log_inference
fi
mkdir ${EXECUTE_PATH}/log_inference
cd ${EXECUTE_PATH}/log_inference || exit
env > ${EXECUTE_PATH}/log_inference/face_recognition.log
python ${EXECUTE_PATH}/../eval.py \
--config_path=${EXECUTE_PATH}/../inference_config.yaml \
--device_target=GPU &> ${EXECUTE_PATH}/log_inference/face_recognition.log &
echo "[INFO] Start inference..."

View File

@ -14,9 +14,9 @@
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
if [ $# != 3 ] && [ $# != 2 ]
then
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
echo "Usage: sh run_export.sh [PRETRAINED_BACKBONE] [DEVICE_TARGET] [USE_DEVICE_ID](optional)"
exit 1
fi
@ -42,9 +42,13 @@ SCRIPT_NAME='export.py'
ulimit -c unlimited
BATCH_SIZE=$1
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)
PRETRAINED_BACKBONE=$(get_real_path $1)
DEVICE_TARGET=$2
if [ $# = 3 ]; then
USE_DEVICE_ID=$3
else
USE_DEVICE_ID=0
fi
if [ ! -f $PRETRAINED_BACKBONE ]
then
@ -52,7 +56,6 @@ if [ ! -f $PRETRAINED_BACKBONE ]
exit 1
fi
echo $BATCH_SIZE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE
@ -65,7 +68,8 @@ cd ${current_exec_path}/device$USE_DEVICE_ID || exit
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--config_path=${dirname_path}/beta_config.yaml \
--pretrained=$PRETRAINED_BACKBONE \
--batch_size=$BATCH_SIZE > convert.log 2>&1 &
--device_target=$DEVICE_TARGET > convert.log 2>&1 &
echo 'running'

View File

@ -0,0 +1,70 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 1 ]; then
echo "Usage: bash run_standalone_train_for_gpu.sh [base/beta] [DEVICE_ID](optional)"
exit 1
fi
expr $2 + 6 &>/dev/null
if [ $? != 0 ]
then
echo "error:DEVICE_ID=$2 is not a integer"
exit 1
fi
if [ $# -eq 2 ]; then
DEVICE_ID=$2
else
DEVICE_ID=0
fi
if [ $1 = "base" ]; then
CONFIG_PATH='./base_config.yaml'
elif [ $1 = "beta" ]; then
CONFIG_PATH='./beta_config.yaml'
else
echo "error: the train_stage is neither base nor beta"
exit 1
fi
export CUDA_VISIBLE_DEVICES=$DEVICE_ID
export DEVICE_ID=0
TRAIN_STAGE=$1
TRAIN_OUTPUT=./standalone_train_for_gpu_$TRAIN_STAGE
EXECUTE_PATH=$(pwd)
echo *******************EXECUTE_PATH=$EXECUTE_PATH
echo TRAIN_OUTPUT=$TRAIN_OUTPUT
echo CONFIG_PATH=$CONFIG_PATH
echo TRAIN_STAGE=$TRAIN_STAGE
echo '*********************************************'
if [ -d $TRAIN_OUTPUT ]; then
rm -rf $TRAIN_OUTPUT
fi
mkdir $TRAIN_OUTPUT
cp ../train.py $TRAIN_OUTPUT
cp ../*.yaml $TRAIN_OUTPUT
cp -r ../model_utils $TRAIN_OUTPUT
cp -r ../src $TRAIN_OUTPUT
cd $TRAIN_OUTPUT || exit
python train.py \
--config_path=$CONFIG_PATH \
--device_target=GPU \
--train_stage=$TRAIN_STAGE \
--is_distributed=0 &> train_standalone_for_gpu.log &
cd ..

View File

@ -162,6 +162,8 @@ class ImageFolderDataset:
with open(cache_path, 'wb') as fw:
pickle.dump(cache, fw)
print('local dump cache:{}'.format(cache_path))
with open(cache_path[:cache_path.rfind('.')] + 'txt', 'w') as _f:
_f.write("Rank 0 dump data to cache_path:'{}' successfully!".format(cache_path))
else:
with open(cache_path, 'wb') as fw:
pickle.dump(cache, fw)

View File

@ -21,18 +21,16 @@ import mindspore.dataset as de
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.transforms.py_transforms as F2
from model_utils.config import config
from src.custom_dataset import DistributedCustomSampler, CustomDataset
__all__ = ['get_de_dataset']
def get_de_dataset(args):
'''get_de_dataset'''
lbl_transforms = [F.ToType(np.int32)]
transform_label = F2.Compose(lbl_transforms)
drop_remainder = False
drop_remainder = True
transforms = [F.ToPIL(),
F.RandomHorizontalFlip(),
@ -40,16 +38,21 @@ def get_de_dataset(args):
F.Normalize(mean=[0.5], std=[0.5])]
transform = F2.Compose(transforms)
cache_path = os.path.join('cache', os.path.basename(args.data_dir), 'data_cache.pkl')
print(cache_path)
if args.device_target == 'GPU' and args.local_rank != 0:
while True:
if os.path.exists(cache_path) and os.path.exists(cache_path[:cache_path.rfind('.')] + 'txt'):
break
with open(cache_path[:cache_path.rfind('.')] + 'txt') as _f:
args.logger.info(_f.readline())
if not os.path.exists(os.path.dirname(cache_path)):
os.makedirs(os.path.dirname(cache_path))
dataset = CustomDataset(args.data_dir, cache_path, args.is_distributed)
args.logger.info("dataset len:{}".format(dataset.__len__()))
if config.device_target == 'Ascend':
if args.device_target in ('Ascend', 'GPU'):
sampler = DistributedCustomSampler(dataset, num_replicas=args.world_size, rank=args.local_rank,
is_distributed=args.is_distributed)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
elif config.device_target == 'CPU':
elif args.device_target == 'CPU':
de_dataset = de.GeneratorDataset(dataset, ["image", "label"])
args.logger.info("after sampler de_dataset datasize :{}".format(de_dataset.get_dataset_size()))
de_dataset = de_dataset.map(input_columns="image", operations=transform)

View File

@ -20,7 +20,7 @@ import mindspore
from mindspore.nn import Cell
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
@ -42,7 +42,11 @@ from model_utils.device_adapter import get_device_id, get_device_num, get_rank_i
mindspore.common.seed.set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False,
device_id=get_device_id(), reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
reserve_class_name_in_scope=False, enable_graph_kernel=config.device_target == "GPU")
if config.device_target == 'Ascend':
context.set_context(enable_auto_mixed_precision=False)
if config.device_target != 'GPU' or not config.is_distributed:
context.set_context(device_id=get_device_id())
class DistributedHelper(Cell):
'''DistributedHelper'''
@ -175,15 +179,38 @@ def modelarts_pre_process():
config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path)
def model_context():
"""set context for facerecognition"""
if config.is_distributed:
parallel_mode = ParallelMode.HYBRID_PARALLEL if config.device_target == 'Ascend' else ParallelMode.DATA_PARALLEL
else:
parallel_mode = ParallelMode.STAND_ALONE
if config.is_distributed:
if config.device_target == 'Ascend':
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=config.world_size, gradients_mean=True)
init()
config.local_rank = get_rank_id()
config.world_size = get_device_num()
elif config.device_target == 'GPU':
init()
device_num = get_group_size()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=parallel_mode,
gradients_mean=True)
config.world_size = get_group_size()
config.local_rank = get_rank()
else:
pass
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
'''run train function.'''
config.local_rank = get_rank_id()
config.world_size = get_device_num()
model_context()
log_path = os.path.join(config.ckpt_path, 'logs')
config.logger = get_logger(log_path, config.local_rank)
support_train_stage = ['base', 'beta']
if config.train_stage.lower() not in support_train_stage:
config.logger.info('your train stage is not support.')
@ -192,13 +219,6 @@ def run_train():
if not os.path.exists(config.data_dir):
config.logger.info('ERROR, data_dir is not exists, please set data_dir in config.py')
raise ValueError('ERROR, data_dir is not exists, please set data_dir in config.py')
parallel_mode = ParallelMode.HYBRID_PARALLEL if config.is_distributed else ParallelMode.STAND_ALONE
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=config.world_size, gradients_mean=True)
if config.is_distributed:
init()
if config.local_rank % 8 == 0:
if not os.path.exists(config.ckpt_path):
os.makedirs(config.ckpt_path)
@ -260,7 +280,7 @@ def run_train():
scale_window=2000)
if config.device_target == "Ascend":
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=scale_manager)
elif config.device_target == "CPU":
elif config.device_target in ("CPU", "GPU"):
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=None)
save_checkpoint_steps = config.ckpt_steps