forked from mindspore-Ecosystem/mindspore

modify FaceAttribute network for cloud

This commit is contained in:
parent ec999d3fa5
commit 969da95546
README.md
@@ -99,10 +99,16 @@ We use about 91K face images as training dataset and 11K as evaluating dataset i
The entire code structure is as follows:

```python
```text
.
└─ Face Attribute
  ├─ README.md
  ├── model_utils
  │   ├──__init__.py          // module init file
  │   ├──config.py            // Parse arguments
  │   ├──device_adapter.py    // Device adapter for ModelArts
  │   ├──local_adapter.py     // Local adapter
  │   ├──moxing_adapter.py    // Moxing adapter for ModelArts
  ├─ scripts
    ├─ run_standalone_train.sh          # launch standalone training(1p) in ascend
    ├─ run_distribute_train.sh          # launch distributed training(8p) in ascend
@@ -117,7 +123,6 @@ The entire code structure is as follows:
    ├─ resnet18.py                      # network backbone
    ├─ head_factory_softmax.py          # network head with softmax
    └─ resnet18_softmax.py              # network backbone with softmax
  ├─ config.py                          # parameter configuration
  ├─ dataset_eval.py                    # dataset loading and preprocessing for evaluating
  ├─ dataset_train.py                   # dataset loading and preprocessing for training
  ├─ logging.py                         # log function
@@ -125,6 +130,9 @@ The entire code structure is as follows:
  ├─ data_to_mindrecord_train.py        # convert dataset to mindrecord for training
  ├─ data_to_mindrecord_train_append.py # add dataset to an existing mindrecord for training
  └─ data_to_mindrecord_eval.py         # convert dataset to mindrecord for evaluating
  ├─ default_config.yaml                # Configurations
  ├─ postprocess.py                     # postprocess scripts
  ├─ preprocess.py                      # preprocess scripts
  ├─ train.py                           # training scripts
  ├─ eval.py                            # evaluation scripts
  └─ export.py                          # export AIR/MINDIR model
@@ -192,6 +200,89 @@ epoch[69], iter[6140], loss:1.158641, 9755.81 imgs/sec
epoch[69], iter[6150], loss:1.167064, 9300.77 imgs/sec
```

- ModelArts (if you want to run on ModelArts, please check the official [ModelArts documentation](https://support.huaweicloud.com/modelarts/); you can then start training as follows)

```bash
# Train 8p on ModelArts
# (1) Perform a or b.
#     a. Set "enable_modelarts=True" in the default_config.yaml file.
#        Set "mindrecord_path='/cache/data/face_attribute_dataset/train/data_train.mindrecord'" in the default_config.yaml file.
#        (optional) Set "checkpoint_url='s3://dir_to_trained_ckpt/'" in the default_config.yaml file if loading a pretrained model.
#        (optional) Set "pretrained='/cache/checkpoint_path/model.ckpt'" in the default_config.yaml file if loading a pretrained model.
#        Set any other parameters you need in the default_config.yaml file.
#     b. Add "enable_modelarts=True" on the website UI interface.
#        Add "mindrecord_path=/cache/data/face_attribute_dataset/train/data_train.mindrecord" on the website UI interface.
#        (optional) Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface if loading a pretrained model.
#        (optional) Add "pretrained=/cache/checkpoint_path/model.ckpt" on the website UI interface if loading a pretrained model.
#        Add any other parameters you need on the website UI interface.
# (2) (optional) Upload or copy your pretrained model to the S3 bucket if loading a pretrained model.
# (3) Upload a zipped dataset to the S3 bucket. (You could also upload the original dataset, but that may be very slow.)
# (4) Set the code directory to "/path/FaceAttribute" on the website UI interface.
# (5) Set the startup file to "train.py" on the website UI interface.
# (6) Set "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
# (7) Create your job.
#
# Train 1p on ModelArts
# (1) Perform a or b.
#     a. Set "enable_modelarts=True" in the default_config.yaml file.
#        Set "world_size=1" in the default_config.yaml file.
#        Set "mindrecord_path='/cache/data/face_attribute_dataset/train/data_train.mindrecord'" in the default_config.yaml file.
#        (optional) Set "checkpoint_url='s3://dir_to_trained_ckpt/'" in the default_config.yaml file if loading a pretrained model.
#        (optional) Set "pretrained='/cache/checkpoint_path/model.ckpt'" in the default_config.yaml file if loading a pretrained model.
#        Set any other parameters you need in the default_config.yaml file.
#     b. Add "enable_modelarts=True" on the website UI interface.
#        Add "world_size=1" on the website UI interface.
#        Add "mindrecord_path=/cache/data/face_attribute_dataset/train/data_train.mindrecord" on the website UI interface.
#        (optional) Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface if loading a pretrained model.
#        (optional) Add "pretrained=/cache/checkpoint_path/model.ckpt" on the website UI interface if loading a pretrained model.
#        Add any other parameters you need on the website UI interface.
# (2) (optional) Upload or copy your pretrained model to the S3 bucket if loading a pretrained model.
# (3) Upload a zipped dataset to the S3 bucket. (You could also upload the original dataset, but that may be very slow.)
# (4) Set the code directory to "/path/FaceAttribute" on the website UI interface.
# (5) Set the startup file to "train.py" on the website UI interface.
# (6) Set "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
# (7) Create your job.
#
# Eval 1p on ModelArts
# (1) Perform a or b.
#     a. Set "enable_modelarts=True" in the default_config.yaml file.
#        Set "mindrecord_path='/cache/data/face_attribute_dataset/train/data_train.mindrecord'" in the default_config.yaml file.
#        Set "checkpoint_url='s3://dir_to_trained_ckpt/'" in the default_config.yaml file.
#        Set "model_path='/cache/checkpoint_path/model.ckpt'" in the default_config.yaml file.
#        Set any other parameters you need in the default_config.yaml file.
#     b. Add "enable_modelarts=True" on the website UI interface.
#        Add "mindrecord_path=/cache/data/face_attribute_dataset/train/data_train.mindrecord" on the website UI interface.
#        Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface.
#        Add "model_path=/cache/checkpoint_path/model.ckpt" on the website UI interface.
#        Add any other parameters you need on the website UI interface.
# (2) Upload or copy your trained model to the S3 bucket.
# (3) Upload a zipped dataset to the S3 bucket. (You could also upload the original dataset, but that may be very slow.)
# (4) Set the code directory to "/path/FaceAttribute" on the website UI interface.
# (5) Set the startup file to "eval.py" on the website UI interface.
# (6) Set "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
# (7) Create your job.
#
# Export 1p on ModelArts
# (1) Perform a or b.
#     a. Set "enable_modelarts=True" in the default_config.yaml file.
#        Set "file_name='faceattri'" in the default_config.yaml file.
#        Set "file_format='MINDIR'" in the default_config.yaml file.
#        Set "checkpoint_url='s3://dir_to_trained_ckpt/'" in the default_config.yaml file.
#        Set "ckpt_file='/cache/checkpoint_path/model.ckpt'" in the default_config.yaml file.
#        Set any other parameters you need in the default_config.yaml file.
#     b. Add "enable_modelarts=True" on the website UI interface.
#        Add "file_name=faceattri" on the website UI interface.
#        Add "file_format=MINDIR" on the website UI interface.
#        Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface.
#        Add "ckpt_file=/cache/checkpoint_path/model.ckpt" on the website UI interface.
#        Add any other parameters you need on the website UI interface.
# (2) Upload or copy your trained model to the S3 bucket.
# (3) Set the code directory to "/path/FaceAttribute" on the website UI interface.
# (4) Set the startup file to "export.py" on the website UI interface.
# (5) Set "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
# (6) Create your job.
```
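
Before submitting a job, it can help to verify what the merged configuration will look like. A minimal sketch, assuming `model_utils/` from this commit is on the import path and `default_config.yaml` sits one directory above it (the default `--config_path`):

```python
# Sketch only (not part of the commit): print the merged configuration
# exactly as train.py/eval.py will see it after YAML defaults and CLI
# overrides are combined by model_utils/config.py.
from model_utils.config import config

print(config.enable_modelarts)  # False unless --enable_modelarts=True is passed
print(config.mindrecord_path)   # "" until set in the YAML or via --mindrecord_path
print(config.world_size)        # 8 by default
```

Every scalar key in the YAML becomes a `--key` command-line flag through `parse_cli_to_yaml`, so the same overrides work for train.py, eval.py and export.py.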

### Evaluation

```bash
default_config.yaml (new file)
@@ -0,0 +1,78 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""

# ==============================================================================
# options
per_batch_size: 128
dst_h: 112
dst_w: 112
workers: 8
attri_num: 3
classes: '9,2,2'
backbone: 'resnet18'
loss_scale: 1024
flat_dim: 512
fc_dim: 256
lr: 0.009
lr_scale: 1
lr_epochs: [20, 30, 50]
weight_decay: 0.0005
momentum: 0.9
max_epoch: 70
warmup_epochs: 0
log_interval: 10
ckpt_path: './output'

# data_to_mindrecord parameter
eval_dataset_txt_file: 'Your_label_txt_file'
eval_mindrecord_file_name: 'Your_output_path/data_test.mindrecord'
train_dataset_txt_file: 'Your_label_txt_file'
train_mindrecord_file_name: 'Your_output_path/data_train.mindrecord'
train_append_dataset_txt_file: 'Your_label_txt_file'
train_append_mindrecord_file_name: 'Your_previous_output_path/data_train.mindrecord0'

# train/eval/preprocess option
mindrecord_path: ""
pretrained: ""
local_rank: 0
world_size: 8
model_path: ""

# export option
ckpt_file: ""
file_name: "faceattri"
file_format: "MINDIR"

---

# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'

# train/eval/preprocess option
mindrecord_path: "dataset path, e.g. /home/data.mindrecord"
pretrained: "pretrained model to load"
local_rank: "current rank to support distributed"
world_size: "current process number to support distributed"
model_path: "pretrained model to load"

# export option
ckpt_file: "pretrained model to load"
file_name: "file name"
file_format: "file format, choices in ['MINDIR', 'AIR']"
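
Note that the file holds two YAML documents separated by `---`: the default values and a help string per key. A minimal sketch of reading them the same way `parse_yaml` in model_utils/config.py does (assumes PyYAML is installed):

```python
# Sketch: default_config.yaml contains two YAML documents; load both.
import yaml

with open("default_config.yaml") as fin:
    cfg, helper = list(yaml.load_all(fin, Loader=yaml.FullLoader))

assert cfg["world_size"] == 8
print(helper["mindrecord_path"])  # "dataset path, e.g. /home/data.mindrecord"
```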
eval.py
@@ -14,7 +14,7 @@
# ============================================================================
"""Face attribute eval."""
import os
import argparse
import time
import numpy as np

from mindspore import context
@@ -23,22 +23,20 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype

from src.dataset_eval import data_generator_eval
from src.config import config
from src.FaceAttribute.resnet18 import get_resnet18

devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num


def softmax(x, axis=0):
    return np.exp(x) / np.sum(np.exp(x), axis=axis)


def main(args):
    network = get_resnet18(args)
    ckpt_path = args.model_path
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
def load_pretrain(checkpoint, network):
    '''load pretrain model.'''
    if os.path.isfile(checkpoint):
        param_dict = load_checkpoint(checkpoint)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
@@ -51,23 +49,77 @@ def main(args):
        print('-----------------------load model success-----------------------')
    else:
        print('-----------------------load model failed-----------------------')
    return network


def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not zip.")
        else:
            print("Zip has been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains at most 8 devices.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))


@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
    '''run eval.'''
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=get_device_id())

    network = get_resnet18(config)
    ckpt_path = config.model_path
    network = load_pretrain(ckpt_path, network)

    network.set_train(False)

    de_dataloader, steps_per_epoch, _ = data_generator_eval(args)
    de_dataloader, steps_per_epoch, _ = data_generator_eval(config)

    total_data_num_age = 0
    total_data_num_gen = 0
    total_data_num_mask = 0
    age_num = 0
    gen_num = 0
    mask_num = 0
    gen_tp_num = 0
    mask_tp_num = 0
    gen_fp_num = 0
    mask_fp_num = 0
    gen_fn_num = 0
    mask_fn_num = 0
    total_data_num_age, total_data_num_gen, total_data_num_mask = 0, 0, 0
    age_num, gen_num, mask_num = 0, 0, 0
    gen_tp_num, mask_tp_num, gen_fp_num = 0, 0, 0
    mask_fp_num, gen_fn_num, mask_fn_num = 0, 0, 0
    for step_i, (data, gt_classes) in enumerate(de_dataloader):

        print('evaluating {}/{} ...'.format(step_i + 1, steps_per_epoch))
@@ -98,11 +150,12 @@ def main(args):
        if gt_mask == mask:
            mask_num += 1

        if gt_gen == 1 and gen == 1:
        if gen == 1:
            if gt_gen == 1:
                gen_tp_num += 1
        if gt_gen == 0 and gen == 1:
            elif gt_gen == 0:
                gen_fp_num += 1
        if gt_gen == 1 and gen == 0:
        elif gen == 0 and gt_gen == 1:
            gen_fn_num += 1

        if gt_mask == 1 and mask == 1:
@@ -165,25 +218,6 @@ def main(args):
        ft.write('mask recall: {}\n'.format(mask_recall))
        ft.write('mask f1: {}\n'.format(mask_f1))


def parse_args():
    """parse_args"""
    parser = argparse.ArgumentParser(description='face attributes eval')
    parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
    parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')

    args_opt = parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args_1 = parse_args()

    args_1.dst_h = config.dst_h
    args_1.dst_w = config.dst_w
    args_1.attri_num = config.attri_num
    args_1.classes = config.classes
    args_1.flat_dim = config.flat_dim
    args_1.fc_dim = config.fc_dim
    args_1.workers = config.workers

    main(args_1)
    run_eval()
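
The gen_*/mask_* counters above feed the precision/recall/F1 numbers that run_eval writes out; a minimal sketch of that arithmetic with hypothetical counts:

```python
# Sketch: the standard precision/recall/F1 arithmetic behind the
# tp/fp/fn counters accumulated in run_eval. Counts are illustrative.
def f1_stats(tp, fp, fn):
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

print(f1_stats(tp=90, fp=5, fn=10))  # hypothetical counts
```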
export.py
@@ -14,7 +14,6 @@
# ============================================================================
"""Convert ckpt to air."""
import os
import argparse
import numpy as np

from mindspore import context
@@ -22,15 +21,22 @@ from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net

from src.FaceAttribute.resnet18_softmax import get_resnet18
from src.config import config
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper


def modelarts_pre_process():
    '''modelarts pre process function.'''
    config.file_name = os.path.join(config.output_path, config.file_name)


@moxing_wrapper(pre_process=modelarts_pre_process)
def run_export():
    '''run export.'''
    devid = 0
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)


def main(args):
    network = get_resnet18(args)
    ckpt_path = args.ckpt_file
    network = get_resnet18(config)
    ckpt_path = config.ckpt_file
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
        param_dict_new = {}
@@ -49,27 +55,9 @@ def main(args):
    input_data = np.random.uniform(low=0, high=1.0, size=(1, 3, 112, 112)).astype(np.float32)
    tensor_input_data = Tensor(input_data)

    export(network, tensor_input_data, file_name=args.file_name,
           file_format=args.file_format)
    export(network, tensor_input_data, file_name=config.file_name,
           file_format=config.file_format)
    print('-----------------------export model success-----------------------')


def parse_args():
    """parse_args"""
    parser = argparse.ArgumentParser(description='Convert ckpt to designated format')
    parser.add_argument('--ckpt_file', type=str, default='', help='pretrained model to load')
    parser.add_argument('--file_name', type=str, default='faceattri', help='file name')
    parser.add_argument('--file_format', type=str, default='MINDIR', choices=['MINDIR', 'AIR'], help='file format')
    args_opt = parser.parse_args()
    return args_opt


if __name__ == "__main__":
    args_1 = parse_args()

    args_1.dst_h = config.dst_h
    args_1.dst_w = config.dst_w
    args_1.attri_num = config.attri_num
    args_1.classes = config.classes
    args_1.flat_dim = config.flat_dim
    args_1.fc_dim = config.fc_dim

    main(args_1)
    run_export()
model_utils/config.py (new file)
@@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        except:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)

config = get_config()
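
A small illustration of the namespace conversion performed by the `Config` class above; the keys are illustrative, not the repository's exact schema:

```python
# Sketch: Config recursively converts dicts (and dicts inside lists)
# into attribute access. "head" is a hypothetical nested key.
cfg = Config({"lr": 0.009, "lr_epochs": [20, 30, 50],
              "head": {"flat_dim": 512, "fc_dim": 256}})
print(cfg.lr)           # 0.009
print(cfg.head.fc_dim)  # 256
print(cfg.lr_epochs)    # [20, 30, 50]
print(cfg)              # pretty-printed via pformat(self.__dict__)
```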
model_utils/device_adapter.py (new file)
@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Device adapter for ModelArts"""

from .config import config

if config.enable_modelarts:
    from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
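
A sketch of the intended call sites, matching the imports this commit adds to eval.py and train.py; callers never branch on ModelArts themselves:

```python
# Sketch: device_adapter transparently resolves to the moxing or local
# implementation depending on config.enable_modelarts.
from model_utils.device_adapter import get_device_id, get_device_num

device_id = get_device_id()    # DEVICE_ID env var when running locally
device_num = get_device_num()  # RANK_SIZE env var when running locally
```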
model_utils/local_adapter.py (new file)
@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Local adapter"""

import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
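
Because the local adapter only reads launcher-set environment variables, a standalone run can be emulated for a quick test; a sketch, assuming no real distributed launcher is involved:

```python
# Sketch: set the variables a launcher would provide, then read them
# back through the local adapter. Defaults are '0' and '1' if unset.
import os
os.environ["DEVICE_ID"] = "0"
os.environ["RANK_SIZE"] = "1"

from model_utils.local_adapter import get_device_id, get_device_num
assert (get_device_id(), get_device_num()) == (0, 1)
```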
model_utils/moxing_adapter.py (new file)
@@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from .config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from remote OBS to a local directory if the first URL is remote
    and the second is a local path; conversely, upload data from a local directory
    to remote OBS.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains at most 8 devices.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            # Run the main function
            run_func(*args, **kwargs)

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
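
A sketch of the decorator's intended usage, mirroring how run_eval and run_train are wrapped in this commit:

```python
# Sketch: moxing_wrapper downloads data_url/checkpoint_url before the
# wrapped function runs and uploads output_path to train_url afterwards
# (only when config.enable_modelarts is True).
from model_utils.moxing_adapter import moxing_wrapper

def prepare():
    print("pre_process hook: runs after data sync, before the main logic")

@moxing_wrapper(pre_process=prepare)
def run():
    print("main training/eval logic goes here")

if __name__ == "__main__":
    run()
```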
preprocess.py
@@ -14,19 +14,12 @@
# ============================================================================
"""preprocess"""
import os
import argparse
from src.config import config

import mindspore.dataset as de
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.transforms.py_transforms as F2

def parse_args():
    """parse_args"""
    parser = argparse.ArgumentParser(description='face attribute dataset to bin')
    parser.add_argument('--model_path', type=str, default='', help='mindir path referenced')
    parser.add_argument('--mindrecord_path', type=str, default='', help='mindir file path')
    args_opt = parser.parse_args()
    return args_opt
from model_utils.config import config

def eval_data_generator(args):
    '''Build eval dataloader.'''
@@ -52,15 +45,7 @@ def eval_data_generator(args):
    return de_dataset

if __name__ == "__main__":
    args_1 = parse_args()
    args_1.dst_h = config.dst_h
    args_1.dst_w = config.dst_w
    args_1.attri_num = config.attri_num
    args_1.classes = config.classes
    args_1.flat_dim = config.flat_dim
    args_1.fc_dim = config.fc_dim
    args_1.workers = config.workers
    ds = eval_data_generator(args_1)
    ds = eval_data_generator(config)
    cur_dir = os.getcwd()
    image_path = os.path.join(cur_dir, './data/image')
    if not os.path.isdir(image_path):
src/config.py (deleted)
@@ -1,46 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Network config setting, will be used in train.py and eval.py"""
from easydict import EasyDict as ed

config = ed({
    'per_batch_size': 128,
    'dst_h': 112,
    'dst_w': 112,
    'workers': 8,
    'attri_num': 3,
    'classes': '9,2,2',
    'backbone': 'resnet18',
    'loss_scale': 1024,
    'flat_dim': 512,
    'fc_dim': 256,
    'lr': 0.009,
    'lr_scale': 1,
    'lr_epochs': [20, 30, 50],
    'weight_decay': 0.0005,
    'momentum': 0.9,
    'max_epoch': 70,
    'warmup_epochs': 0,
    'log_interval': 10,
    'ckpt_path': '../../output',

    # data_to_mindrecord parameter
    'eval_dataset_txt_file': 'Your_label_txt_file',
    'eval_mindrecord_file_name': 'Your_output_path/data_test.mindrecord',
    'train_dataset_txt_file': 'Your_label_txt_file',
    'train_mindrecord_file_name': 'Your_output_path/data_train.mindrecord',
    'train_append_dataset_txt_file': 'Your_label_txt_file',
    'train_append_mindrecord_file_name': 'Your_previous_output_path/data_train.mindrecord0'
})
data_to_mindrecord_eval.py
@@ -17,7 +17,7 @@ import numpy as np

from mindspore.mindrecord import FileWriter

from config import config
from model_utils.config import config

dataset_txt_file = config.eval_dataset_txt_file

data_to_mindrecord_train.py
@@ -17,7 +17,7 @@ import numpy as np

from mindspore.mindrecord import FileWriter

from config import config
from model_utils.config import config

dataset_txt_file = config.train_dataset_txt_file

data_to_mindrecord_train_append.py
@@ -17,7 +17,7 @@ import numpy as np

from mindspore.mindrecord import FileWriter

from config import config
from model_utils.config import config

dataset_txt_file = config.train_append_dataset_txt_file

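
All three conversion scripts build on MindSpore's `FileWriter`; a minimal sketch of the general write pattern (the schema below is illustrative, not the scripts' exact schema):

```python
# Sketch: the FileWriter pattern the conversion scripts build on.
from mindspore.mindrecord import FileWriter

writer = FileWriter(file_name="data_train.mindrecord", shard_num=1)
schema = {"image": {"type": "bytes"}, "label": {"type": "int32"}}
writer.add_schema(schema, "face_attribute_schema")
writer.write_raw_data([{"image": b"\x00", "label": 0}])  # toy record
writer.commit()
```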
train.py
@@ -16,7 +16,6 @@
import os
import time
import datetime
import argparse

import mindspore
import mindspore.nn as nn
@@ -36,10 +35,10 @@ from src.FaceAttribute.loss_factory import get_loss
from src.dataset_train import data_generator
from src.lrsche_factory import warmup_step
from src.logging import get_logger, AverageMeter
from src.config import config

devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num


class InternalCallbackParam(dict):
@@ -65,88 +64,99 @@ class BuildTrainNetwork(nn.Cell):
        return loss0


def parse_args():
    '''Argument for Face Attributes.'''
    parser = argparse.ArgumentParser('Face Attributes')
def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done.")
            else:
                print("This is not zip.")
        else:
            print("Zip has been extracted.")

    parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
    parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)

    arg, _ = parser.parse_known_args()
        sync_lock = "/tmp/unzip_sync.lock"

    return arg
        # Each server contains at most 8 devices.
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))

    config.ckpt_path = os.path.join(config.output_path, config.ckpt_path)


if __name__ == "__main__":
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
    '''run train.'''
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=get_device_id())
    mindspore.set_seed(1)

    # logger
    args = parse_args()

    # init distributed
    if args.world_size != 1:
    if config.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()

    args.per_batch_size = config.per_batch_size
    args.dst_h = config.dst_h
    args.dst_w = config.dst_w
    args.workers = config.workers
    args.attri_num = config.attri_num
    args.classes = config.classes
    args.backbone = config.backbone
    args.loss_scale = config.loss_scale
    args.flat_dim = config.flat_dim
    args.fc_dim = config.fc_dim
    args.lr = config.lr
    args.lr_scale = config.lr_scale
    args.lr_epochs = config.lr_epochs
    args.weight_decay = config.weight_decay
    args.momentum = config.momentum
    args.max_epoch = config.max_epoch
    args.warmup_epochs = config.warmup_epochs
    args.log_interval = config.log_interval
    args.ckpt_path = config.ckpt_path

    if args.world_size == 1:
        args.per_batch_size = 256
    else:
        args.lr = args.lr * 4.

    if args.world_size != 1:
        config.local_rank = get_rank()
        config.world_size = get_group_size()
        config.lr = config.lr * 4.
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        config.per_batch_size = 256
        parallel_mode = ParallelMode.STAND_ALONE

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=config.world_size)

    # model and log save path
    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.local_rank)
    config.outputs_dir = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    config.logger = get_logger(config.outputs_dir, config.local_rank)
    loss_meter = AverageMeter('loss')

    # dataloader
    args.logger.info('start create dataloader')
    de_dataloader, steps_per_epoch, num_classes = data_generator(args)
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('end create dataloader')
    args.logger.save_args(args)
    config.logger.info('start create dataloader')
    de_dataloader, steps_per_epoch, num_classes = data_generator(config)
    config.steps_per_epoch = steps_per_epoch
    config.num_classes = num_classes
    config.logger.info('end create dataloader')
    config.logger.save_args(config)

    # backbone and loss
    args.logger.important_info('start create network')
    # backbone && loss && load pretrain model
    config.logger.important_info('start create network')
    create_network_start = time.time()
    network = get_resnet18(args)

    network = get_resnet18(config)
    criterion = get_loss()

    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
    if os.path.isfile(config.pretrained):
        param_dict = load_checkpoint(config.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
@@ -156,30 +166,24 @@ if __name__ == "__main__":
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))
        config.logger.info('load model %s success', config.pretrained)

    # optimizer and lr scheduler
    lr = warmup_step(args, gamma=0.1)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
    lr = warmup_step(config, gamma=0.1)
    opt = Momentum(params=network.trainable_params(), learning_rate=lr, momentum=config.momentum,
                   weight_decay=config.weight_decay, loss_scale=config.loss_scale)

    train_net = BuildTrainNetwork(network, criterion)

    # mixed precision training
    criterion.add_flags_recursive(fp32=True)

    # package training process
    train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
    train_net = TrainOneStepCell(train_net, opt, sens=config.loss_scale)
    context.reset_auto_parallel_context()

    # checkpoint
    if args.local_rank == 0:
        ckpt_max_num = args.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
    if config.local_rank == 0:
        ckpt_max_num = config.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=config.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=config.outputs_dir,
                                  prefix='{}'.format(config.local_rank))
        cb_params = InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
@@ -201,39 +205,41 @@ if __name__ == "__main__":
        loss = train_net(data_tensor, gt_tensor)
        loss_meter.update(loss.asnumpy()[0])

        # save ckpt
        if args.local_rank == 0:
        if config.local_rank == 0:
            cb_params.cur_step_num = i + 1
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
        if i % config.steps_per_epoch == 0 and config.local_rank == 0:
            cb_params.cur_epoch_num += 1

        # save Log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))
            config.logger.important_info(
                '{}, graph compile time={:.2f}s'.format(config.backbone, time_for_graph_compile))

        if i % args.log_interval == 0 and args.local_rank == 0:
        if i % config.log_interval == 0 and config.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            epoch = int(i / config.steps_per_epoch)
            fps = config.per_batch_size * (i - old_progress) * config.world_size / time_used
            config.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))

            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
        if i % config.steps_per_epoch == 0 and config.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
            args.logger.info('=================================================')
            args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            args.logger.info('=================================================')
            epoch = int(i / config.steps_per_epoch)
            fps = config.per_batch_size * config.world_size * config.steps_per_epoch / epoch_time_used
            config.logger.info('=================================================')
            config.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            config.logger.info('=================================================')
            t_epoch = time.time()

        i += 1

    args.logger.info('--------- trains out ---------')
    config.logger.info('--------- trains out ---------')


if __name__ == "__main__":
    run_train()
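
The training setup above follows the standard MindSpore cell-wrapping pattern; a compressed sketch with a toy network and loss standing in for the repository's ResNet-18 and loss factory:

```python
# Sketch: the cell-wrapping pattern train.py uses, with toy modules.
import mindspore.nn as nn

net = nn.Dense(8, 3)                          # stand-in for get_resnet18(config)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
train_net = nn.WithLossCell(net, criterion)   # plays the role of BuildTrainNetwork
opt = nn.Momentum(net.trainable_params(), learning_rate=0.009, momentum=0.9)
train_net = nn.TrainOneStepCell(train_net, opt, sens=1024)  # sens mirrors loss_scale
train_net.set_train()
```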