forked from mindspore-Ecosystem/mindspore
!19004 Xception can been used on ModelArts
Merge pull request !19004 from 郑彬/master
This commit is contained in:
commit
7bb22a0924
|
@ -86,11 +86,17 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
|
|||
├─run_infer_310.sh # shell script for 310 inference
|
||||
└─run_eval_gpu.sh # launch evaluating with gpu platform
|
||||
├─src
|
||||
├─config.py # parameter configuration
|
||||
├─model_utils
|
||||
├─config.py # parsing parameter configuration file of "*.yaml"
|
||||
├─device_adapter.py # local or ModelArts training
|
||||
├─local_adapter.py # get related environment variables on local
|
||||
└─moxing_adapter.py # get related environment variables abd transfer data on ModelArts
|
||||
├─dataset.py # data preprocessing
|
||||
├─Xception.py # network definition
|
||||
├─loss.py # Customized CrossEntropy loss function
|
||||
└─lr_generator.py # learning rate generator
|
||||
├─default_config.yaml # parameter configuration
|
||||
├─mindspore_hub_conf.py # mindspore hub interface
|
||||
├─train.py # train net
|
||||
├─postprogress.py # post process for 310 inference
|
||||
├─export.py # export net
|
||||
|
@ -100,7 +106,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
|
|||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py.
|
||||
Parameters for both training and evaluation can be set in `default_config.yaml`.
|
||||
|
||||
- Config on ascend
|
||||
|
||||
|
@ -234,6 +240,33 @@ epoch: 2 step: 20018, loss is 5.179064
|
|||
epoch time: 5628609.779 ms, per step time: 281.177 ms
|
||||
```
|
||||
|
||||
### Training with 8 cards on ModelArts
|
||||
|
||||
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder to S3 bucket.
|
||||
# (2) Click to "create training task" on the website UI interface.
|
||||
# (3) Set the code directory to "/{path}/xception" on the website UI interface.
|
||||
# (4) Set the startup file to /{path}/xception/train.py" on the website UI interface.
|
||||
# (5) Perform a or b.
|
||||
# a. setting parameters in /{path}/xception/default_config.yaml.
|
||||
# 1. Set ”enable_modelarts: True“
|
||||
# 2. Set “is_distributed: True”
|
||||
# 3. Set “modelarts_dataset_unzip_name: {folder_name}", if the data is uploaded in the form of zip package.
|
||||
# 4. Set “folder_name_under_zip_file: {path}”, (dateset path under the unzip folder, such as './ImageNet_Original/train')
|
||||
# b. adding on the website UI interface.
|
||||
# 1. Add ”enable_modelarts=True“
|
||||
# 2. Add “is_distributed: True”
|
||||
# 3. Add “modelarts_dataset_unzip_name: {folder_name}", if the data is uploaded in the form of zip package.
|
||||
# 4. Add “folder_name_under_zip_file: {path}”, (dateset path under the unzip folder, such as './ImageNet_Original/train')
|
||||
# (6) Upload the mindrecdrd dataset to S3 bucket.
|
||||
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path.
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Under the item "resource pool selection", select the specification of 8 cards..
|
||||
# (10) Create your job.
|
||||
```
|
||||
|
||||
## [Eval process](#contents)
|
||||
|
||||
### Usage
|
||||
|
@ -283,14 +316,68 @@ result: {'Loss': 1.7797744848789312, 'Top_1_Acc': 0.7985777243589743, 'Top_5_Acc
|
|||
result: {'Loss': 1.7846775874590903, 'Top_1_Acc': 0.798735595390525, 'Top_5_Acc': 0.9498439500640204}
|
||||
```
|
||||
|
||||
### Evaluating with single card on ModelArts
|
||||
|
||||
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder 'xception' to S3 bucket.
|
||||
# (2) Click to "create training task" on the website UI interface.
|
||||
# (3) Set the code directory to "/{path}/xception" on the website UI interface.
|
||||
# (4) Set the startup file to /{path}/xception/eval.py" on the website UI interface.
|
||||
# (5) Perform a or b.
|
||||
# a. setting parameters in /{path}/xception/default_config.yaml.
|
||||
# 1. Set ”enable_modelarts: True“
|
||||
# 2. Set “checkpoint_path: ./{path}/*.ckpt”('load_checkpoint_path' indicates the path of the weight file to be evaluated relative to the file `eval.py`, and the weight file must be included in the code directory.)
|
||||
# 3. Set “modelarts_dataset_unzip_name: {folder_name}", if the data is uploaded in the form of zip package.
|
||||
# 4. Set “folder_name_under_zip_file: {path}”, (dateset path under the unzip folder, such as './ImageNet_Original/validation_preprocess')
|
||||
# b. adding on the website UI interface.
|
||||
# 1. Add ”enable_modelarts: True“
|
||||
# 2. Add “checkpoint_path: ./{path}/*.ckpt”('load_checkpoint_path' indicates the path of the weight file to be evaluated relative to the file `eval.py`, and the weight file must be included in the code directory.)
|
||||
# 3. Add “modelarts_dataset_unzip_name: {folder_name}", if the data is uploaded in the form of zip package.
|
||||
# 4. Add “folder_name_under_zip_file: {path}”, (dateset path under the unzip folder, such as './ImageNet_Original/validation_preprocess')
|
||||
# (6) Upload the dataset(not mindrecord format) to S3 bucket.
|
||||
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path.
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Under the item "resource pool selection", select the specification of a single card.
|
||||
# (10) Create your job.
|
||||
```
|
||||
|
||||
## [Export process](#contents)
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format[EXPORT_FORMAT] --batch_size [BATCH_SIZE]
|
||||
```
|
||||
- Export on local
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format[EXPORT_FORMAT] --batch_size [BATCH_SIZE]
|
||||
```
|
||||
|
||||
`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]
|
||||
|
||||
- Export on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start as follows)
|
||||
|
||||
```python
|
||||
# (1) Upload the code folder to S3 bucket.
|
||||
# (2) Click to "create training task" on the website UI interface.
|
||||
# (3) Set the code directory to "/{path}/xception" on the website UI interface.
|
||||
# (4) Set the startup file to /{path}/xception/export.py" on the website UI interface.
|
||||
# (5) Perform a or b.
|
||||
# a. setting parameters in /{path}/xception/default_config.yaml.
|
||||
# 1. Set ”enable_modelarts: True“
|
||||
# 2. Set “ckpt_file: ./{path}/*.ckpt”('ckpt_file' indicates the path of the weight file to be exported relative to the file `export.py`, and the weight file must be included in the code directory.)
|
||||
# 3. Set ”file_name: xception“
|
||||
# 4. Set ”file_format:MINDIR“
|
||||
# b. adding on the website UI interface.
|
||||
# 1. Add ”enable_modelarts=True“
|
||||
# 2. Add “ckpt_file=./{path}/*.ckpt”('ckpt_file' indicates the path of the weight file to be exported relative to the file `export.py`, and the weight file must be included in the code directory.)
|
||||
# 3. Add ”file_name=xception“
|
||||
# 4. Add ”file_format=MINDIR“
|
||||
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (This step is useless, but necessary.).
|
||||
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (9) Under the item "resource pool selection", select the specification of a single card.
|
||||
# (10) Create your job.
|
||||
# You will see xception.mindir under {Output file path}.
|
||||
```
|
||||
|
||||
## [Inference process](#contents)
|
||||
|
||||
### Inference
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# Url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# Path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "Ascend"
|
||||
enable_profiling: False
|
||||
|
||||
modelarts_dataset_unzip_name: ''
|
||||
folder_name_under_zip_file: './'
|
||||
# ==============================================================================
|
||||
device_id: 0
|
||||
# train related
|
||||
is_distributed: False
|
||||
train_data_dir: ''
|
||||
is_fp32: False
|
||||
resume: ''
|
||||
# eval related
|
||||
checkpoint_path: ''
|
||||
test_data_dir: ''
|
||||
# export related
|
||||
batch_size: 1
|
||||
ckpt_file: ''
|
||||
width: 299
|
||||
height: 299
|
||||
file_name: "xception"
|
||||
file_format: "MINDIR"
|
||||
# config on GPU for Xception, imagenet2012.
|
||||
config_gpu:
|
||||
class_num: 1000
|
||||
batch_size: 64
|
||||
loss_scale: 1024
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
epoch_size: 250
|
||||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 1
|
||||
keep_checkpoint_max: 5
|
||||
save_checkpoint_path: "./gpu-ckpt"
|
||||
warmup_epochs: 1
|
||||
lr_decay_mode: "linear"
|
||||
use_label_smooth: True
|
||||
finish_epoch: 0
|
||||
label_smooth_factor: 0.1
|
||||
lr_init: 0.00004
|
||||
lr_max: 0.4
|
||||
lr_end: 0.00004
|
||||
# config on Ascend for Xception, imagenet2012.
|
||||
config_ascend:
|
||||
class_num: 1000
|
||||
batch_size: 128
|
||||
loss_scale: 1024
|
||||
momentum: 0.9
|
||||
weight_decay: 0.0001 # 1e-4
|
||||
epoch_size: 250
|
||||
save_checkpoint: True
|
||||
save_checkpoint_epochs: 1
|
||||
keep_checkpoint_max: 5
|
||||
save_checkpoint_path: "./"
|
||||
warmup_epochs: 1
|
||||
lr_decay_mode: "liner"
|
||||
use_label_smooth: True
|
||||
finish_epoch: 0
|
||||
label_smooth_factor: 0.1
|
||||
lr_init: 0.00004
|
||||
lr_max: 0.4
|
||||
lr_end: 0.00004
|
||||
---
|
||||
# Help description for ModelArts configuration
|
||||
enable_modelarts: "Whether training on modelarts, default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of the input data."
|
||||
output_path: "The location of the output file."
|
||||
device_target: "Running platform, choose from Ascend GPU or CPU(only for export), and default is Ascend."
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
# Help description for train configuration
|
||||
is_distributed: 'distributed training'
|
||||
train_data_dir: 'train dataset dir'
|
||||
is_fp32: 'fp32 training'
|
||||
resume: ''
|
||||
# Help description for eval configuration
|
||||
device_id: 'Device id'
|
||||
checkpoint_path: 'Checkpoint file path'
|
||||
test_data_dir: 'test Dataset dir'
|
||||
# Help description for export configuration
|
||||
batch_size: "batch size"
|
||||
ckpt_file: "xception ckpt file."
|
||||
width: "input width"
|
||||
height: "input height"
|
||||
file_name: "xception output file name."
|
||||
file_format: "file format"
|
||||
#
|
||||
---
|
||||
file_format: ["AIR", "MINDIR"]
|
||||
device_target: ["Ascend", "GPU", "CPU"]
|
|
@ -13,27 +13,81 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval Xception."""
|
||||
import argparse
|
||||
import time
|
||||
import os
|
||||
from mindspore import context, nn
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.Xception import xception
|
||||
from src.config import config_gpu, config_ascend
|
||||
from src.dataset import create_dataset
|
||||
from src.loss import CrossEntropySmooth
|
||||
from src.model_utils.config import config as args_opt, config_gpu, config_ascend
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--device_target', type=str, default='GPU', help='Device target')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='Device id')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, args_opt.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("Unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if args_opt.modelarts_dataset_unzip_name:
|
||||
zip_file_1 = os.path.join(args_opt.data_path, args_opt.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(args_opt.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
args_opt.test_data_dir = args_opt.data_path
|
||||
if args_opt.modelarts_dataset_unzip_name:
|
||||
args_opt.test_data_dir = os.path.join(args_opt.test_data_dir, args_opt.folder_name_under_zip_file)
|
||||
args_opt.checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), args_opt.checkpoint_path)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_eval():
|
||||
if args_opt.device_target == "Ascend":
|
||||
config = config_ascend
|
||||
elif args_opt.device_target == "GPU":
|
||||
|
@ -45,8 +99,8 @@ if __name__ == '__main__':
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
|
||||
|
||||
# create dataset
|
||||
dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=config.batch_size, device_num=1, rank=0)
|
||||
step_size = dataset.get_dataset_size()
|
||||
dataset = create_dataset(args_opt.test_data_dir, do_train=False, batch_size=config.batch_size, device_num=1, rank=0)
|
||||
# step_size = dataset.get_dataset_size()
|
||||
|
||||
# define net
|
||||
net = xception(class_num=config.class_num)
|
||||
|
@ -68,3 +122,7 @@ if __name__ == '__main__':
|
|||
# eval model
|
||||
res = model.eval(dataset, dataset_sink_mode=True)
|
||||
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_eval()
|
||||
|
|
|
@ -13,38 +13,34 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval Xception."""
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.Xception import xception
|
||||
from src.config import config_ascend, config_gpu
|
||||
from src.model_utils.config import config as args, config_gpu, config_ascend
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
|
||||
parser = argparse.ArgumentParser(description="Image classification")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="xception ckpt file.")
|
||||
parser.add_argument("--width", type=int, default=299, help="input width")
|
||||
parser.add_argument("--height", type=int, default=299, help="input height")
|
||||
parser.add_argument("--file_name", type=str, default="xception", help="xception output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.device_target == "Ascend":
|
||||
config = config_ascend
|
||||
elif args.device_target == "GPU":
|
||||
config = config_gpu
|
||||
else:
|
||||
raise ValueError("Unsupported device_target.")
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
args.ckpt_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.ckpt_file)
|
||||
args.file_name = os.path.join(args.output_path, args.file_name)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# define net
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_export():
|
||||
'''export function'''
|
||||
if args.device_target == "Ascend":
|
||||
config = config_ascend
|
||||
elif args.device_target == "GPU":
|
||||
config = config_gpu
|
||||
else:
|
||||
raise ValueError("Unsupported device_target.")
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
context.set_context(device_id=args.device_id)
|
||||
net = xception(class_num=config.class_num)
|
||||
|
||||
# load checkpoint
|
||||
|
@ -54,3 +50,7 @@ if __name__ == "__main__":
|
|||
|
||||
image = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
|
||||
export(net, image, file_name=args.file_name, file_format=args.file_format)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_export()
|
||||
|
|
|
@ -40,8 +40,8 @@ do
|
|||
|
||||
env > env.log
|
||||
python ../train.py \
|
||||
--is_distributed \
|
||||
--is_distributed=True \
|
||||
--device_target=Ascend \
|
||||
--dataset_path=$DATA_DIR > log.txt 2>&1 &
|
||||
--train_data_dir=$DATA_DIR > log.txt 2>&1 &
|
||||
cd ../
|
||||
done
|
||||
|
|
|
@ -27,5 +27,5 @@ python ../eval.py \
|
|||
--device_target=Ascend \
|
||||
--device_id=$DEVICE_ID \
|
||||
--checkpoint_path=$PATH_CHECKPOINT \
|
||||
--dataset_path=$DATA_DIR > eval.log 2>&1 &
|
||||
--test_data_dir=$DATA_DIR > eval.log 2>&1 &
|
||||
cd ../
|
||||
|
|
|
@ -27,6 +27,5 @@ python ../eval.py \
|
|||
--device_target=GPU \
|
||||
--device_id=$DEVICE_ID \
|
||||
--checkpoint_path=$PATH_CHECKPOINT \
|
||||
--dataset_path=$DATA_DIR
|
||||
#--dataset_path=$DATA_DIR > eval_gpu.log 2>&1 &
|
||||
--test_data_dir=$DATA_DIR > eval_gpu.log 2>&1 &
|
||||
cd ../
|
||||
|
|
|
@ -24,5 +24,5 @@ echo "start training standalone on device $DEVICE_ID"
|
|||
|
||||
python ../train.py \
|
||||
--device_target=Ascend \
|
||||
--dataset_path=$DATA_DIR > log.txt 2>&1 &
|
||||
--train_data_dir=$DATA_DIR > log.txt 2>&1 &
|
||||
cd ../
|
||||
|
|
|
@ -35,9 +35,9 @@ then
|
|||
--output-filename gpu_fp16_dist_log \
|
||||
--merge-stderr-to-stdout \
|
||||
${PYTHON_EXEC} ../train.py \
|
||||
--is_distributed \
|
||||
--is_distributed=True \
|
||||
--device_target=GPU \
|
||||
--dataset_path=$DATA_DIR > gpu_fp16_dist_log.txt 2>&1 &
|
||||
--train_data_dir=$DATA_DIR > gpu_fp16_dist_log.txt 2>&1 &
|
||||
else
|
||||
PATH_TRAIN="./train_standalone_gpu_fp16"$(date "+%Y%m%d%H%M%S")
|
||||
if [ -d $PATH_TRAIN ];
|
||||
|
@ -50,6 +50,6 @@ else
|
|||
|
||||
${PYTHON_EXEC} ../train.py \
|
||||
--device_target=GPU \
|
||||
--dataset_path=$DATA_DIR > gpu_fp16_standard_log.txt 2>&1 &
|
||||
--train_data_dir=$DATA_DIR > gpu_fp16_standard_log.txt 2>&1 &
|
||||
fi
|
||||
cd ../
|
|
@ -35,10 +35,10 @@ then
|
|||
--output-filename gpu_fp32_dist_log \
|
||||
--merge-stderr-to-stdout \
|
||||
${PYTHON_EXEC} ../train.py \
|
||||
--is_distributed \
|
||||
--is_distributed=True \
|
||||
--device_target=GPU \
|
||||
--is_fp32 \
|
||||
--dataset_path=$DATA_DIR > gpu_fp32_dist_log.txt 2>&1 &
|
||||
--is_fp32=True \
|
||||
--train_data_dir=$DATA_DIR > gpu_fp32_dist_log.txt 2>&1 &
|
||||
else
|
||||
PATH_TRAIN="./train_standalone_gpu_fp32"$(date "+%Y%m%d%H%M%S")
|
||||
if [ -d $PATH_TRAIN ];
|
||||
|
@ -52,7 +52,7 @@ else
|
|||
#${PYTHON_EXEC} ../train.py \ --dataset_path=/gdata/ImageNet2012/train/
|
||||
${PYTHON_EXEC} ../train.py \
|
||||
--device_target=GPU \
|
||||
--is_fp32 \
|
||||
--dataset_path=$DATA_DIR > gpu_fp32_standard_log.txt 2>&1 &
|
||||
--is_fp32=True \
|
||||
--train_data_dir=$DATA_DIR > gpu_fp32_standard_log.txt 2>&1 &
|
||||
fi
|
||||
cd ../
|
|
@ -1,62 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
# config on GPU for Xception, imagenet2012.
|
||||
config_gpu = ed({
|
||||
"class_num": 1000,
|
||||
"batch_size": 64,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1e-4,
|
||||
"epoch_size": 250,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 5,
|
||||
"save_checkpoint_path": "./gpu-ckpt",
|
||||
"warmup_epochs": 1,
|
||||
"lr_decay_mode": "linear",
|
||||
"use_label_smooth": True,
|
||||
"finish_epoch": 0,
|
||||
"label_smooth_factor": 0.1,
|
||||
"lr_init": 0.00004,
|
||||
"lr_max": 0.4,
|
||||
"lr_end": 0.00004
|
||||
})
|
||||
|
||||
# config on Ascend for Xception, imagenet2012.
|
||||
config_ascend = ed({
|
||||
"class_num": 1000,
|
||||
"batch_size": 128,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1e-4,
|
||||
"epoch_size": 250,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 5,
|
||||
"save_checkpoint_path": "./",
|
||||
"warmup_epochs": 1,
|
||||
"lr_decay_mode": "liner",
|
||||
"use_label_smooth": True,
|
||||
"finish_epoch": 0,
|
||||
"label_smooth_factor": 0.1,
|
||||
"lr_init": 0.00004,
|
||||
"lr_max": 0.4,
|
||||
"lr_end": 0.00004
|
||||
})
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
|
||||
config = get_config()
|
||||
config_gpu = config.config_gpu
|
||||
config_ascend = config.config_ascend
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(config)
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Device adapter for ModelArts"""
|
||||
|
||||
from src.model_utils.config import config
|
||||
|
||||
if config.enable_modelarts:
|
||||
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
else:
|
||||
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
||||
|
||||
__all__ = [
|
||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
||||
]
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Local adapter"""
|
||||
|
||||
import os
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
return "Local Job"
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Moxing adapter for ModelArts"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from mindspore import context
|
||||
from mindspore.profiler import Profiler
|
||||
from src.model_utils.config import config
|
||||
|
||||
_global_sync_count = 0
|
||||
|
||||
def get_device_id():
|
||||
device_id = os.getenv('DEVICE_ID', '0')
|
||||
return int(device_id)
|
||||
|
||||
|
||||
def get_device_num():
|
||||
device_num = os.getenv('RANK_SIZE', '1')
|
||||
return int(device_num)
|
||||
|
||||
|
||||
def get_rank_id():
|
||||
global_rank_id = os.getenv('RANK_ID', '0')
|
||||
return int(global_rank_id)
|
||||
|
||||
|
||||
def get_job_id():
|
||||
job_id = os.getenv('JOB_ID')
|
||||
job_id = job_id if job_id != "" else "default"
|
||||
return job_id
|
||||
|
||||
def sync_data(from_path, to_path):
|
||||
"""
|
||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
||||
Upload data from local directory to remote obs in contrast.
|
||||
"""
|
||||
import moxing as mox
|
||||
import time
|
||||
global _global_sync_count
|
||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
||||
_global_sync_count += 1
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("from path: ", from_path)
|
||||
print("to path: ", to_path)
|
||||
mox.file.copy_parallel(from_path, to_path)
|
||||
print("===finish data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
# print("os.mknod({}) success".format(sync_lock))
|
||||
except IOError:
|
||||
pass
|
||||
print("===save flag===")
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
||||
|
||||
|
||||
def moxing_wrapper(pre_process=None, post_process=None):
|
||||
"""
|
||||
Moxing wrapper to download dataset and upload outputs.
|
||||
"""
|
||||
def wrapper(run_func):
|
||||
@functools.wraps(run_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# Download data from data_url
|
||||
if config.enable_modelarts:
|
||||
if config.data_url:
|
||||
sync_data(config.data_url, config.data_path)
|
||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
||||
if config.checkpoint_url:
|
||||
sync_data(config.checkpoint_url, config.load_path)
|
||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
||||
if config.train_url:
|
||||
sync_data(config.train_url, config.output_path)
|
||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
||||
|
||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
||||
config.device_num = get_device_num()
|
||||
config.device_id = get_device_id()
|
||||
if not os.path.exists(config.output_path):
|
||||
os.makedirs(config.output_path)
|
||||
|
||||
if pre_process:
|
||||
pre_process()
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler = Profiler()
|
||||
|
||||
run_func(*args, **kwargs)
|
||||
|
||||
if config.enable_profiling:
|
||||
profiler.analyse()
|
||||
|
||||
# Upload data to train_url
|
||||
if config.enable_modelarts:
|
||||
if post_process:
|
||||
post_process()
|
||||
|
||||
if config.train_url:
|
||||
print("Start to copy output directory")
|
||||
sync_data(config.output_path, config.train_url)
|
||||
return wrapped_func
|
||||
return wrapper
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""train Xception."""
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
|
@ -22,29 +22,80 @@ from mindspore.nn.optim.momentum import Momentum
|
|||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.lr_generator import get_lr
|
||||
from src.Xception import xception
|
||||
from src.config import config_gpu, config_ascend
|
||||
from src.dataset import create_dataset
|
||||
from src.loss import CrossEntropySmooth
|
||||
|
||||
from src.model_utils.config import config as args_opt, config_gpu, config_ascend
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification training')
|
||||
parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='run platform, (Default: Ascend)')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='dataset path')
|
||||
parser.add_argument("--is_fp32", action='store_true', default=False, help='fp32 training, add --is_fp32')
|
||||
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
def modelarts_pre_process():
|
||||
'''modelarts pre process function.'''
|
||||
def unzip(zip_file, save_dir):
|
||||
import zipfile
|
||||
s_time = time.time()
|
||||
if not os.path.exists(os.path.join(save_dir, args_opt.modelarts_dataset_unzip_name)):
|
||||
zip_isexist = zipfile.is_zipfile(zip_file)
|
||||
if zip_isexist:
|
||||
fz = zipfile.ZipFile(zip_file, 'r')
|
||||
data_num = len(fz.namelist())
|
||||
print("Extract Start...")
|
||||
print("Unzip file num: {}".format(data_num))
|
||||
data_print = int(data_num / 100) if data_num > 100 else 1
|
||||
i = 0
|
||||
for file in fz.namelist():
|
||||
if i % data_print == 0:
|
||||
print("Unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
|
||||
i += 1
|
||||
fz.extract(file, save_dir)
|
||||
print("Cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
|
||||
int(int(time.time() - s_time) % 60)))
|
||||
print("Extract Done.")
|
||||
else:
|
||||
print("This is not zip.")
|
||||
else:
|
||||
print("Zip has been extracted.")
|
||||
|
||||
if args_opt.modelarts_dataset_unzip_name:
|
||||
zip_file_1 = os.path.join(args_opt.data_path, args_opt.modelarts_dataset_unzip_name + ".zip")
|
||||
save_dir_1 = os.path.join(args_opt.data_path)
|
||||
|
||||
sync_lock = "/tmp/unzip_sync.lock"
|
||||
|
||||
# Each server contains 8 devices as most.
|
||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
||||
print("Zip file path: ", zip_file_1)
|
||||
print("Unzip file save dir: ", save_dir_1)
|
||||
unzip(zip_file_1, save_dir_1)
|
||||
print("===Finish extract data synchronization===")
|
||||
try:
|
||||
os.mknod(sync_lock)
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
if os.path.exists(sync_lock):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
|
||||
args_opt.train_data_dir = args_opt.data_path
|
||||
if args_opt.modelarts_dataset_unzip_name:
|
||||
args_opt.train_data_dir = os.path.join(args_opt.train_data_dir, args_opt.folder_name_under_zip_file)
|
||||
config_gpu.save_checkpoint_path = os.path.join(args_opt.output_path, config_gpu.save_checkpoint_path)
|
||||
config_ascend.save_checkpoint_path = os.path.join(args_opt.output_path, config_ascend.save_checkpoint_path)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def run_train():
|
||||
if args_opt.device_target == "Ascend":
|
||||
config = config_ascend
|
||||
elif args_opt.device_target == "GPU":
|
||||
|
@ -54,20 +105,19 @@ if __name__ == '__main__':
|
|||
|
||||
# init distributed
|
||||
if args_opt.is_distributed:
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
context.set_context(device_id=get_device_id(), mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
||||
save_graphs=False)
|
||||
init()
|
||||
rank = get_rank()
|
||||
group_size = get_group_size()
|
||||
rank = get_rank_id()
|
||||
group_size = get_device_num()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
|
||||
else:
|
||||
rank = 0
|
||||
group_size = 1
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
|
||||
device_id = get_device_id()
|
||||
context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
||||
save_graphs=False)
|
||||
# define network
|
||||
net = xception(class_num=config.class_num)
|
||||
if args_opt.device_target == "Ascend":
|
||||
|
@ -79,7 +129,7 @@ if __name__ == '__main__':
|
|||
loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
|
||||
# define dataset
|
||||
dataset = create_dataset(args_opt.dataset_path, do_train=True, batch_size=config.batch_size,
|
||||
dataset = create_dataset(args_opt.train_data_dir, do_train=True, batch_size=config.batch_size,
|
||||
device_num=group_size, rank=rank)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
|
@ -137,3 +187,7 @@ if __name__ == '__main__':
|
|||
cb += [ckpt_cb]
|
||||
model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
print("train success")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_train()
|
||||
|
|
Loading…
Reference in New Issue