!18694 modify FaceQuality network for cloud

Merge pull request !18694 from zhanghuiyao/FaceQuality_clould
i-robot 2021-06-23 03:15:58 +00:00 committed by Gitee
commit a97c7cfed6
16 changed files with 633 additions and 150 deletions

View File

@ -85,10 +85,16 @@ We use about 122K face images as training dataset and 2K as evaluating dataset i
The entire code structure is as follows:
```python
```text
.
└─ Face Quality Assessment
  ├─ README.md
  ├─ model_utils
    ├─ __init__.py               # module init file
    ├─ config.py                 # Parse arguments
    ├─ device_adapter.py         # Device adapter for ModelArts
    ├─ local_adapter.py          # Local adapter
    └─ moxing_adapter.py         # Moxing adapter for ModelArts
  ├─ scripts
    ├─ run_standalone_train.sh   # launch standalone training(1p) in ascend
    ├─ run_distribute_train.sh   # launch distributed training(8p) in ascend
@ -102,12 +108,12 @@ The entire code structure is as following:
    ├─ run_eval_cpu.sh           # launch evaluating in cpu
    └─ run_export_cpu.sh         # launch exporting mindir model in cpu
  ├─ src
    ├─ config.py                 # parameter configuration
    ├─ dataset.py                # dataset loading and preprocessing for training
    ├─ face_qa.py                # network backbone
    ├─ log.py                    # log function
    ├─ loss_factory.py           # loss function
    └─ lr_generator.py           # generate learning rate
  ├─ default_config.yaml         # Configurations
  ├─ train.py                    # training scripts
  ├─ eval.py                     # evaluation scripts
  └─ export.py                   # export air model
@ -225,6 +231,95 @@ epoch[39], iter[19100], loss:2.140766, 8088.52 imgs/sec
epoch[39], iter[19110], loss:2.111101, 8791.05 imgs/sec
```
- ModelArts (if you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); you can then start training as follows)
```bash
# Train 8p on ModelArts with Ascend
# (1) Perform a or b.
# a. Set "enable_modelarts=True" on default_config.yaml file.
# Set "is_distributed=1" on default_config.yaml file.
# Set "per_batch_size=32" on default_config.yaml file.
# Set "train_label_file='/cache/data/face_quality_dataset/qa_300W_LP_train.txt'" on default_config.yaml file.
# (option) Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_config.yaml file if load pretrain.
# (option) Set "pretrained='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file if load pretrain.
# Set other parameters on default_config.yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "is_distributed=1" on the website UI interface.
# Add "per_batch_size=32" on the website UI interface.
# Add "train_label_file=/cache/data/face_quality_dataset/qa_300W_LP_train.txt" on the website UI interface.
# (option) Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface if load pretrain.
# (option) Add "pretrained=/cache/checkpoint_path/model.ckpt" on the website UI interface if load pretrain.
# Add other parameters on the website UI interface.
# (2) (Optional) Upload or copy your pretrained model to the S3 bucket if loading a pretrained model.
# (3) Modify the image paths in the "/dir_to_your_dataset/qa_300W_LP_train.txt" file.
# (4) Upload a zipped dataset to the S3 bucket. (You could also upload the original dataset, but it can be slow.)
# (5) Set the code directory to "/path/FaceQualityAssessment" on the website UI interface.
# (6) Set the startup file to "train.py" on the website UI interface.
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (8) Create your job.
#
# Train 1p on ModelArts with Ascend
# (1) Perform a or b.
# a. Set "enable_modelarts=True" on default_config.yaml file.
# Set "is_distributed=0" on default_config.yaml file.
# Set "per_batch_size=256" on default_config.yaml file.
# Set "train_label_file='/cache/data/face_quality_dataset/qa_300W_LP_train.txt'" on default_config.yaml file.
# (option) Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_config.yaml file if load pretrain.
# (option) Set "pretrained='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file if load pretrain.
# Set other parameters on default_config.yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "is_distributed=0" on the website UI interface.
# Add "per_batch_size=256" on the website UI interface.
# Add "train_label_file=/cache/data/face_quality_dataset/qa_300W_LP_train.txt" on the website UI interface.
# (option) Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface if load pretrain.
# (option) Add "pretrained=/cache/checkpoint_path/model.ckpt" on the website UI interface if load pretrain.
# Add other parameters on the website UI interface.
# (2) (Optional) Upload or copy your pretrained model to the S3 bucket if loading a pretrained model.
# (3) Modify the image paths in the "/dir_to_your_dataset/qa_300W_LP_train.txt" file.
# (4) Upload a zipped dataset to the S3 bucket. (You could also upload the original dataset, but it can be slow.)
# (5) Set the code directory to "/path/FaceQualityAssessment" on the website UI interface.
# (6) Set the startup file to "train.py" on the website UI interface.
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (8) Create your job.
#
# Eval 1p on ModelArts with Ascend
# (1) Perform a or b.
# a. Set "enable_modelarts=True" on default_config.yaml file.
# Set "eval_dir='/cache/data/face_quality_dataset/AFLW2000'" on default_config.yaml file.
# Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_config.yaml file.
# Set "pretrained='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file.
# Set other parameters you need on default_config.yaml file.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "eval_dir=/cache/data/face_quality_dataset/AFLW2000" on the website UI interface.
# Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface.
# Add "pretrained=/cache/checkpoint_path/model.ckpt" on the website UI interface.
# Add other parameters on the website UI interface.
# (2) Upload or copy your trained model to S3 bucket.
# (3) Upload a zipped dataset to the S3 bucket. (You could also upload the original dataset, but it can be slow.)
# (4) Set the code directory to "/path/FaceQualityAssessment" on the website UI interface.
# (5) Set the startup file to "eval.py" on the website UI interface.
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (7) Create your job.
#
# Export 1p on ModelArts with Ascend
# (1) Perform a or b.
# a. Set "enable_modelarts=True" on default_config.yaml file.
# Set "batch_size=8" on default_config.yaml file.
# Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_config.yaml file.
# Set "pretrained='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file.
# Set other parameters you need on default_config.yaml file.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "batch_size=8" on the website UI interface.
# Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface.
# Add "pretrained=/cache/checkpoint_path/model.ckpt" on the website UI interface.
# Add other parameters on the website UI interface.
# (2) Upload or copy your trained model to S3 bucket.
# (3) Set the code directory to "/path/FaceQualityAssessment" on the website UI interface.
# (4) Set the startup file to "export.py" on the website UI interface.
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (6) Create your job.
```
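The key=value pairs above work because every key in default_config.yaml is registered as a command line flag by model_utils/config.py (added later in this commit); booleans are coerced with ast.literal_eval rather than bool(). Below is a minimal, self-contained sketch of that coercion, with illustrative key names only:
```python
# Hedged sketch: how string overrides such as "enable_modelarts=True"
# become typed options. The keys here are illustrative defaults.
import argparse
import ast

defaults = {"enable_modelarts": False, "is_distributed": 0, "per_batch_size": 256}

parser = argparse.ArgumentParser()
for key, value in defaults.items():
    # bool("False") would be True, so booleans are parsed with ast.literal_eval.
    arg_type = ast.literal_eval if isinstance(value, bool) else type(value)
    parser.add_argument("--" + key, type=arg_type, default=value)

args = parser.parse_args(["--enable_modelarts=True", "--per_batch_size=32"])
print(args.enable_modelarts, args.per_batch_size)  # True 32
```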
### Evaluation
```bash

View File

@ -0,0 +1,64 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# URL for ModelArts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: True
modelarts_dataset_unzip_name: "face_quality_dataset"
# ==============================================================================
# options
task: 'face_qa'
# dataset related
per_batch_size: 256 # batch size when running on 1p
#per_batch_size: 32 # batch size when running on 8p
# network structure related
steps_per_epoch: 0
loss_scale: 1024
# optimizer related
lr: 0.02
lr_scale: 1
lr_epochs: '10, 20, 30'
weight_decay: 0.0005
momentum: 0.9
max_epoch: 40
warmup_epochs: 0
pretrained: ''
# logging related
log_interval: 10
ckpt_path: './output'
ckpt_interval: 500
# train option
is_distributed: 0
train_label_file: ''
# eval option
eval_dir: ''
# export option
batch_size: 8
file_name: 'FaceQualityAssessment'
file_format: 'AIR'
---
# Help description for each configuration
is_distributed: "if multi device"
train_label_file: "image label list file, e.g. /home/label.txt"
pretrained: "pretrained model to load"
device_target: "device target, choices in ['Ascend', 'GPU', 'CPU']"
eval_dir: "eval image dir, e.g. /home/test"
batch_size: "batch size for export"
file_name: "output file name"
file_format: "file format, choices in ['AIR', 'ONNX', 'MINDIR']"
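The `---` separator splits default_config.yaml into multiple YAML documents: the first carries the option values, the second the per-key help text, and an optional third the allowed choices. A hedged sketch of reading that layout back, assuming PyYAML (this mirrors the parse_yaml helper added in model_utils/config.py below):
```python
# Sketch: reading the multi-document layout of default_config.yaml
# (values, help descriptions, optional choices).
import yaml

with open("default_config.yaml", "r") as fin:
    docs = list(yaml.load_all(fin, Loader=yaml.FullLoader))

cfg = docs[0]                              # option values
helper = docs[1] if len(docs) > 1 else {}  # help string per key
print(cfg["per_batch_size"])               # 256
print(helper.get("pretrained"))            # "pretrained model to load"
```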

View File

@ -14,8 +14,8 @@
# ============================================================================
"""Face Quality Assessment eval."""
import os
import time
import warnings
import argparse
import numpy as np
import cv2
from tqdm import tqdm
@ -28,6 +28,10 @@ from mindspore import context
from src.face_qa import FaceQABackbone
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num
warnings.filterwarnings('ignore')
@ -99,11 +103,64 @@ reshape = P.Reshape()
argmax = P.ArgMaxWithValue()
def test_trains(args):
'''test trains'''
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
'''run eval'''
print('----eval----begin----')
model_path = args.pretrained
model_path = config.pretrained
result_file = model_path.replace('.ckpt', '.txt')
if os.path.exists(result_file):
os.remove(result_file)
@ -130,7 +187,7 @@ def test_trains(args):
print('wrong model path')
return 1
path = args.eval_dir
path = config.eval_dir
kp_error_all = [[], [], [], [], []]
eulers_error_all = [[], [], []]
kp_ipn = []
@ -205,17 +262,8 @@ def test_trains(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Face Quality Assessment')
parser.add_argument('--eval_dir', type=str, default='', help='eval image dir, e.g. /home/test')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device target')
arg = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=arg.device_target, save_graphs=False)
if arg.device_target == 'Ascend':
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
if config.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
test_trains(arg)
run_eval()

View File

@ -14,7 +14,6 @@
# ============================================================================
"""Convert ckpt to air/mindir."""
import os
import argparse
import numpy as np
from mindspore import context
@ -23,10 +22,25 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in
from src.face_qa import FaceQABackbone
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
def modelarts_pre_process():
'''modelarts pre process function.'''
config.file_name = os.path.join(config.output_path, config.file_name)
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_export():
'''run export.'''
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
if config.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
def main(args):
network = FaceQABackbone()
ckpt_path = args.pretrained
ckpt_path = config.pretrained
if os.path.isfile(ckpt_path):
param_dict = load_checkpoint(ckpt_path)
param_dict_new = {}
@ -42,28 +56,12 @@ def main(args):
else:
print('-----------------------load model failed -----------------------')
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 96, 96)).astype(np.float32)
input_data = np.random.uniform(low=0, high=1.0, size=(config.batch_size, 3, 96, 96)).astype(np.float32)
tensor_input_data = Tensor(input_data)
export(network, tensor_input_data, file_name=args.file_name, file_format=args.file_format)
export(network, tensor_input_data, file_name=config.file_name, file_format=config.file_format)
print('-----------------------export model success-----------------------')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Convert ckpt to air/mindir')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device target')
parser.add_argument('--file_name', type=str, default='FaceQualityAssessment', help='output file name')
parser.add_argument('--file_format', type=str, choices=['AIR', 'ONNX', 'MINDIR'], default='AIR', help='file format')
arg = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=arg.device_target, save_graphs=False)
if arg.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
main(arg)
run_export()

View File

@ -0,0 +1,126 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
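As a usage note (a hedged sketch, not part of this commit): any script that imports this module receives the YAML defaults merged with command line overrides, exposed as attributes rather than dict lookups:
```python
# Sketch: consuming the config singleton; run e.g. as
#   python my_script.py --per_batch_size=32
# where my_script.py is a hypothetical caller inside the repository.
from model_utils.config import config

print(config.per_batch_size)  # 256 unless --per_batch_size overrides it
print(config.device_target)   # 'Ascend' unless --device_target overrides it
```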

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id else "default"  # JOB_ID may be unset (None) or empty
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote OBS to a local directory when the first path is a remote URL and the second is a local path;
otherwise, upload data from the local directory to remote OBS.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
# Run the main function
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
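A hedged usage sketch (hypothetical entry point, not part of this commit) of how a script composes moxing_wrapper with a pre_process hook, mirroring the pattern eval.py, export.py, and train.py adopt in this commit:
```python
# Sketch: a minimal entry point decorated with moxing_wrapper.
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper

def prepare_dataset():
    # Runs only on ModelArts (enable_modelarts=True), after the OBS download,
    # e.g. to unzip the archive copied into config.data_path.
    print("preparing data under", config.data_path)

@moxing_wrapper(pre_process=prepare_dataset)
def main():
    print("start training with per_batch_size =", config.per_batch_size)

if __name__ == "__main__":
    main()
```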

View File

@ -80,6 +80,7 @@ do
dev=`expr $i + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--per_batch_size=32 \
--is_distributed=1 \
--train_label_file=$TRAIN_LABEL_FILE \
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &

View File

@ -47,12 +47,14 @@ then
mpirun -n $1 --allow-run-as-root python ${BASEPATH}/../train.py \
--train_label_file=$3 \
--is_distributed=1 \
--per_batch_size=32 \
--device_target='GPU' \
--pretrained=$4 > train.log 2>&1 &
else
python ${BASEPATH}/../train.py \
--train_label_file=$3 \
--is_distributed=0 \
--per_batch_size=256 \
--device_target='GPU' \
--pretrained=$4 > train.log 2>&1 &
fi
@ -62,11 +64,13 @@ else
mpirun -n $1 --allow-run-as-root python ${BASEPATH}/../train.py \
--train_label_file=$3 \
--is_distributed=1 \
--per_batch_size=32 \
--device_target='GPU' > train.log 2>&1 &
else
python ${BASEPATH}/../train.py \
--train_label_file=$3 \
--is_distributed=0 \
--per_batch_size=256 \
--device_target='GPU' > train.log 2>&1 &
fi
fi

View File

@ -77,6 +77,7 @@ dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
--is_distributed=0 \
--per_batch_size=256 \
--train_label_file=$TRAIN_LABEL_FILE \
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &

View File

@ -31,11 +31,13 @@ cd ${current_exec_path}/cpu || exit
if [ $2 ] # pretrained ckpt
then
python ${BASEPATH}/../train.py \
--per_batch_size=256 \
--train_label_file=$1 \
--device_target='CPU' \
--pretrained=$2 > train.log 2>&1 &
else
python ${BASEPATH}/../train.py \
--per_batch_size=256 \
--train_label_file=$1 \
--device_target='CPU' > train.log 2>&1 &
fi

View File

@ -32,10 +32,12 @@ if [ $2 ] # pretrained ckpt
then
python ${BASEPATH}/../train.py \
--train_label_file=$1 \
--per_batch_size=256 \
--device_target='GPU' \
--pretrained=$2 > train.log 2>&1 &
else
python ${BASEPATH}/../train.py \
--train_label_file=$1 \
--per_batch_size=256 \
--device_target='GPU' > train.log 2>&1 &
fi

View File

@ -1,76 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network config setting, will be used in train.py and eval.py"""
from easydict import EasyDict as edict
faceqa_1p_cfg = edict({
'task': 'face_qa',
# dataset related
'per_batch_size': 256,
# network structure related
'steps_per_epoch': 0,
'loss_scale': 1024,
# optimizer related
'lr': 0.02,
'lr_scale': 1,
'lr_epochs': '10, 20, 30',
'weight_decay': 0.0005,
'momentum': 0.9,
'max_epoch': 40,
'warmup_epochs': 0,
'pretrained': '',
'local_rank': 0,
'world_size': 1,
# logging related
'log_interval': 10,
'ckpt_path': '../../output',
'ckpt_interval': 500,
'device_id': 0,
})
faceqa_8p_cfg = edict({
'task': 'face_qa',
# dataset related
'per_batch_size': 32,
# network structure related
'steps_per_epoch': 0,
'loss_scale': 1024,
# optimizer related
'lr': 0.02,
'lr_scale': 1,
'lr_epochs': '10, 20, 30',
'weight_decay': 0.0005,
'momentum': 0.9,
'max_epoch': 40,
'warmup_epochs': 0,
'pretrained': '',
'local_rank': 0,
'world_size': 8,
# logging related
'log_interval': 10, # 10
'ckpt_path': '../../output',
'ckpt_interval': 500,
})

View File

@ -16,7 +16,6 @@
import os
import time
import datetime
import argparse
import warnings
import numpy as np
@ -31,37 +30,95 @@ from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank
from src.loss import CriterionsFaceQA
from src.config import faceqa_1p_cfg, faceqa_8p_cfg
from src.face_qa import FaceQABackbone, BuildTrainNetwork
from src.lr_generator import warmup_step
from src.dataset import faceqa_dataset
from src.log import get_logger, AverageMeter
from model_utils.config import config as cfg
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num
warnings.filterwarnings('ignore')
mindspore.common.seed.set_seed(1)
def main(args):
if args.is_distributed == 0:
cfg = faceqa_1p_cfg
else:
cfg = faceqa_8p_cfg
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, cfg.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
cfg.data_lst = args.train_label_file
cfg.pretrained = args.pretrained
if cfg.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(cfg.data_path, cfg.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(cfg.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
cfg.ckpt_path = os.path.join(cfg.output_path, cfg.ckpt_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
'''run train.'''
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target, save_graphs=False)
if cfg.device_target == 'Ascend':
context.set_context(device_id=get_device_id())
cfg.data_lst = cfg.train_label_file
# Init distributed
if args.is_distributed:
if cfg.is_distributed:
init()
cfg.local_rank = get_rank()
cfg.world_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
else:
cfg.local_rank = 0
cfg.world_size = 1
parallel_mode = ParallelMode.STAND_ALONE
# parallel_mode 'STAND_ALONE' does not support parameter_broadcast and mirror_mean
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size,
gradients_mean=True)
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size, gradients_mean=True)
mindspore.common.set_seed(1)
@ -104,11 +161,8 @@ def main(args):
# optimizer and lr scheduler
lr = warmup_step(cfg, gamma=0.9)
opt = Momentum(params=network.trainable_params(),
learning_rate=lr,
momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale)
opt = Momentum(params=network.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
weight_decay=cfg.weight_decay, loss_scale=cfg.loss_scale)
# package training process, adjust lr + forward + backward + optimizer
train_net = BuildTrainNetwork(network, criterion)
@ -142,7 +196,6 @@ def main(args):
loss = train_net(data, gt)
loss_meter.update(loss.asnumpy())
# ckpt
if cfg.local_rank == 0:
cb_params.cur_step_num = i + 1 # current step number
cb_params.batch_num = i + 2
@ -175,18 +228,4 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Face Quality Assessment')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--train_label_file', type=str, default='', help='image label list file, e.g. /home/label.txt')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device target')
arg = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=arg.device_target, save_graphs=False)
if arg.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
main(arg)
run_train()