merge dpn

parent 61c96a444f
commit 151c523acf
README.md
@@ -89,13 +89,13 @@ The DPN models use the ImageNet-1K dataset to train and validate in this repository.
 To train the DPNs, run the shell script `scripts/train_standalone.sh` with the format below:

 ```shell
-sh scripts/train_standalone.sh [device_id] [dataset_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)]
+sh scripts/train_standalone.sh [device_id] [train_data_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)]
 ```

 To validate the DPNs, run the shell script `scripts/eval.sh` with the format below:

 ```shell
-sh scripts/eval.sh [device_id] [dataset_dir] [pretrained_ckpt]
+sh scripts/eval.sh [device_id] [eval_data_dir] [checkpoint_path]
 ```

 # [Script Description](#contents)
@@ -116,6 +116,11 @@ The structure of the files in this repository is shown below.
 │  ├─ dpn.py               // DPN implementation
 │  ├─ imagenet_dataset.py  // dataset processor and provider
 │  └─ lr_scheduler.py      // DPN learning rate scheduler
+│  ├─ model_utils
+│  │  ├─ config.py          // parameter config
+│  │  ├─ moxing_adapter.py  // ModelArts device configuration
+│  │  ├─ device_adapter.py  // device config
+│  │  └─ local_adapter.py   // local device config
 ├─ eval.py    // evaluation script
 ├─ train.py   // training script
 ├─ export.py  // export model
@@ -124,11 +129,11 @@ The structure of the files in this repository is shown below.

 ## [Script Parameters](#contents)

-Parameters for both training and evaluation can be set in `src/config.py`.
+Parameters for training, evaluation, and export can be set in `default_config.yaml`.

 - Configurations for DPN92 with the ImageNet-1K dataset

-```python
+```default_config.yaml
 # model config
 config.image_size = (224,224)  # input image size
 config.num_classes = 1000      # number of dataset classes
@@ -174,7 +179,7 @@ config.keep_checkpoint_max = 3 # only keep the last keep_checkpoint
 Run `scripts/train_standalone.sh` to train the model standalone. The usage of the script is:

 ```shell
-sh scripts/train_standalone.sh [device_id] [dataset_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)]
+sh scripts/train_standalone.sh [device_id] [train_data_dir] [ckpt_path_to_save] [eval_each_epoch] [pretrained_ckpt(optional)]
 ```

 For example, you can run the shell command below to launch the training procedure.
@@ -212,10 +217,16 @@ The model checkpoint will be saved into `[ckpt_path_to_save]`.

 #### Running on Ascend

 For distributed training, an HCCL configuration file in JSON format must be created in advance.

 Please follow the instructions in the link below:

 <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>

 Run `scripts/train_distributed.sh` to train the model in distributed mode. The usage of the script is:

 ```text
-sh scripts/train_distributed.sh [rank_table] [dataset_dir] [ckpt_path_to_save] [rank_size] [eval_each_epoch] [pretrained_ckpt(optional)]
+sh scripts/train_distributed.sh [rank_table] [train_data_dir] [ckpt_path_to_save] [rank_size] [eval_each_epoch] [pretrained_ckpt(optional)]
 ```

 For example, you can run the shell command below to launch the training procedure.
@@ -243,7 +254,7 @@ The model checkpoint will be saved into `[ckpt_path_to_save]`.
 Run `scripts/eval.sh` to evaluate the model with one Ascend processor. The usage of the script is:

 ```text
-sh scripts/eval.sh [device_id] [dataset_dir] [pretrained_ckpt]
+sh scripts/eval.sh [device_id] [eval_data_dir] [checkpoint_path]
 ```

 For example, you can run the shell command below to launch the validation procedure.
@@ -259,6 +270,58 @@ Evaluation result: {'top_5_accuracy': 0.9449223751600512, 'top_1_accuracy': 0.79
 DPN evaluate success!
 ```

+- Running on ModelArts
+- To train the model on ModelArts, refer to the [official ModelArts guidance](https://support.huaweicloud.com/modelarts/).
+
+  ```python
+  # Example of distributed DPN training on ModelArts:
+  # dataset layout in the bucket
+
+  # ├── ImageNet_Original          # root dir
+  #     ├── train                  # train dir
+  #         ├── train_dataset      # training dataset dir
+  #         ├── train_predtrained  # pretrained checkpoint dir, if one exists
+  #     ├── eval                   # eval dir
+  #         ├── eval_dataset       # evaluation dataset dir
+  #         ├── checkpoint         # ckpt files dir
+
+  # (1) Choose either a (modify the yaml file parameters) or b (set the parameters when creating the ModelArts training job).
+  #     a. Set "enable_modelarts=True".
+  #        Set "is_distributed=1".
+  #        Set "ckpt_path=/cache/train/outputs_imagenet/".
+  #        Set "train_data_dir=/cache/data/train/train_dataset/".
+  #        Set "pretrained=/cache/data/train/train_predtrained/[pretrained ckpt file name]"; without pretrained weights, set pretrained="".
+  #     b. Add "enable_modelarts=True" on the ModelArts interface,
+  #        and set the remaining parameters of method a there as well.
+  #        Note: path parameters do not need to be quoted.
+
+  # (2) Set the path of the network configuration file: "_config_path=/The path of config in default_config.yaml/".
+  # (3) Set the code path on the ModelArts interface: "/path/dpn".
+  # (4) Set the model's startup file on the ModelArts interface: "train.py".
+  # (5) On the ModelArts interface, set the model's data path to ".../ImageNet_Original" (the ImageNet_Original folder path),
+  #     the model's output path "Output file path", and the model's log path "Job log path".
+  # (6) Start training the model.
+
+  # Example of running model inference on ModelArts
+  # (1) Place the trained model in the corresponding location in the bucket.
+  # (2) Choose either a or b.
+  #     a. Set "enable_modelarts=True".
+  #        Set "eval_data_dir=/cache/data/eval/eval_dataset/".
+  #        Set "checkpoint_path=/cache/data/eval/checkpoint/".
+  #     b. Add "enable_modelarts=True" on the ModelArts interface,
+  #        and set the remaining parameters of method a there as well.
+  #        Note: path parameters do not need to be quoted.
+
+  # (3) Set the path of the network configuration file: "_config_path=/The path of config in default_config.yaml/".
+  # (4) Set the code path on the ModelArts interface: "/path/dpn".
+  # (5) Set the model's startup file on the ModelArts interface: "eval.py".
+  # (6) On the ModelArts interface, set the model's data path to ".../ImageNet_Original" (the ImageNet_Original folder path),
+  #     the model's output path "Output file path", and the model's log path "Job log path".
+  # (7) Start model inference.
+  ```

 # [Model Description](#contents)

 ## [Performance](#contents)
default_config.yaml (new file)
@@ -0,0 +1,88 @@
+# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+# ======================================================================================
+# common options
+is_distributed: 0
+image_size: [224, 224]
+batch_size: 32
+num_parallel_workers: 4
+rank: 0
+group_size: 1
+num_classes: 1000
+label_smooth: False
+label_smooth_factor: 0.0
+
+# ======================================================================================
+# Training options
+backbone: 'dpn92'
+is_save_on_master: True
+
+# training config
+pretrained: ""
+ckpt_path: "./"
+eval_each_epoch: 0
+global_step: 0
+epoch_size: 180
+loss_scale_num: 1024
+momentum: 0.9
+weight_decay: 1e-4
+
+# learning rate config
+lr_schedule: "warmup"
+lr_init: 0.01
+lr_max: 0.1
+factor: 0.1
+epoch_number_to_drop: [5, 15]
+warmup_epochs: 5
+
+# dataset config
+train_data_dir: ""
+dataset: "imagenet-1K"
+keep_checkpoint_max: 3
+
+# ======================================================================================
+# Eval options
+eval_data_dir: ""
+checkpoint_path: ""
+
+# export options
+device_id: 0
+width: 224
+height: 224
+file_name: "dpn"
+file_format: "MINDIR"
+
+---
+# Help description for each configuration
+enable_modelarts: "Whether to run training on ModelArts, default: False"
+data_url: "URL for ModelArts"
+train_url: "URL for ModelArts"
+data_path: "The location of the input data"
+output_path: "The location of the output file"
+device_target: "Target device: Ascend or GPU (default: Ascend)"
+enable_profiling: "Whether to enable profiling while training, default: False"
+train_data_dir: "ImageNet training data dir"
+pretrained: "ckpt path to load"
+is_distributed: "Whether to run on multiple devices"
+ckpt_path: "ckpt path to save"
+eval_each_epoch: "Evaluate after each epoch"
+eval_data_dir: "Evaluation data dir"
+checkpoint_path: "ckpt path to load"
+device_id: "Device id"
+width: "Input width"
+height: "Input height"
+file_name: "dpn output file name"
+file_format: "Choices: [AIR, ONNX, MINDIR]"
eval.py
@@ -13,70 +13,53 @@
 # limitations under the License.
 # ============================================================================
 """DPN model eval with MindSpore"""
 import os
-import argparse

 from mindspore import context
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.model import Model
 from mindspore.common import set_seed
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

-from src.dpn import dpns
-from src.config import config
 from src.imagenet_dataset import classification_dataset
+from src.dpn import dpns
 from src.crossentropy import CrossEntropy
+from src.model_utils.config import config
+from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.device_adapter import get_device_id


 set_seed(1)


 # set context
-device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE,
-                    device_target="Ascend", save_graphs=False, device_id=device_id)
+                    device_target=config.device_target, save_graphs=False, device_id=get_device_id())


-def parse_args():
-    """parameters"""
-    parser = argparse.ArgumentParser('dpn evaluating')
-    # dataset related
-    parser.add_argument('--data_dir', type=str, default='', help='eval data dir')
-    # network related
-    parser.add_argument('--pretrained', type=str, default='', help='ckpt path to load')
-    args, _ = parser.parse_known_args()
-    args.image_size = config.image_size
-    args.num_classes = config.num_classes
-    args.batch_size = config.batch_size
-    args.num_parallel_workers = config.num_parallel_workers
-    args.backbone = config.backbone
-    args.loss_scale_num = config.loss_scale_num
-    args.rank = config.rank
-    args.group_size = config.group_size
-    args.dataset = config.dataset
-    return args
-
-
-def dpn_evaluate(args):
+@moxing_wrapper(pre_process=None)
+def dpn_evaluate():
     # create evaluate dataset
-    eval_path = os.path.join(args.data_dir, 'val')
-    eval_dataset = classification_dataset(eval_path,
-                                          image_size=args.image_size,
-                                          num_parallel_workers=args.num_parallel_workers,
-                                          per_batch_size=args.batch_size,
+    # eval_path = os.path.join(config.eval_data_dir, 'val')
+    eval_dataset = classification_dataset(config.eval_data_dir,
+                                          image_size=config.image_size,
+                                          num_parallel_workers=config.num_parallel_workers,
+                                          per_batch_size=config.batch_size,
                                           max_epoch=1,
-                                          rank=args.rank,
+                                          rank=config.rank,
                                           shuffle=False,
-                                          group_size=args.group_size,
+                                          group_size=config.group_size,
                                           mode='eval')

     # create network
-    net = dpns[args.backbone](num_classes=args.num_classes)
+    net = dpns[config.backbone](num_classes=config.num_classes)
     # load checkpoint
-    load_param_into_net(net, load_checkpoint(args.pretrained))
-    print("load checkpoint from [{}].".format(args.pretrained))
+    load_param_into_net(net, load_checkpoint(config.checkpoint_path))
+    print("load checkpoint from [{}].".format(config.checkpoint_path))
     # loss
-    if args.dataset == "imagenet-1K":
+    if config.dataset == "imagenet-1K":
         loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     else:
-        if not args.label_smooth:
-            args.label_smooth_factor = 0.0
-        loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
+        if not config.label_smooth:
+            config.label_smooth_factor = 0.0
+        loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)

     # create model
     model = Model(net, amp_level="O2", keep_batchnorm_fp32=False, loss_fn=loss,

@@ -87,5 +70,5 @@ def dpn_evaluate(args):


 if __name__ == '__main__':
-    dpn_evaluate(parse_args())
+    dpn_evaluate()
     print('DPN evaluate success!')
export.py
@@ -12,30 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Export DPN"""
-import argparse
+"""Export DPN
+Suggested usage: python export.py --file_name [file name] --file_format [file format] --checkpoint_path [ckpt path]
+"""

 import numpy as np

 from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export

 from src.dpn import dpns
-from src.config import config
+from src.model_utils.config import config

-parser = argparse.ArgumentParser(description="export dpn")
-parser.add_argument("--device_id", type=int, default=0, help="device id")
-parser.add_argument("--ckpt_file", type=str, required=True, help="dpn ckpt file.")
-parser.add_argument("--width", type=int, default=224, help="input width")
-parser.add_argument("--height", type=int, default=224, help="input height")
-parser.add_argument("--file_name", type=str, default="dpn", help="dpn output file name.")
-parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"],
-                    default="MINDIR", help="file format")
-parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
-                    help="device target")
-args = parser.parse_args()
-
-context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
-if args.device_target == "Ascend":
-    context.set_context(device_id=args.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+if config.device_target == "Ascend":
+    context.set_context(device_id=config.device_id)

 if __name__ == "__main__":
     # define net

@@ -44,9 +33,9 @@ if __name__ == "__main__":
     net = dpns[backbone](num_classes=num_classes)

     # load checkpoint
-    param_dict = load_checkpoint(args.ckpt_file)
+    param_dict = load_checkpoint(config.checkpoint_path)
     load_param_into_net(net, param_dict)
     net.set_train(False)

-    image = Tensor(np.zeros([config.batch_size, 3, args.height, args.width], np.float32))
-    export(net, image, file_name=args.file_name, file_format=args.file_format)
+    image = Tensor(np.zeros([config.batch_size, 3, config.height, config.width], np.float32))
+    export(net, image, file_name=config.file_name, file_format=config.file_format)
scripts/eval.sh
@@ -18,5 +18,5 @@ DATA_DIR=$2
 PATH_CHECKPOINT=$3

 python eval.py \
-    --pretrained=$PATH_CHECKPOINT \
-    --data_dir=$DATA_DIR > eval_log.txt 2>&1 &
+    --checkpoint_path=$PATH_CHECKPOINT \
+    --eval_data_dir=$DATA_DIR > eval_log.txt 2>&1 &
scripts/train_distributed.sh
@@ -36,6 +36,7 @@ do
     rm -rf ./train_parallel$i
     mkdir ./train_parallel$i
     cp -r ./src ./train_parallel$i
+    cp ./*yaml ./train_parallel$i
     cp ./train.py ./train_parallel$i
     echo "start training for rank $i, device $DEVICE_ID"

@@ -47,12 +48,12 @@ do
         --is_distributed=1 \
         --ckpt_path=$SAVE_PATH \
         --eval_each_epoch=$EVAL_EACH_EPOCH \
-        --data_dir=$DATA_DIR > log.txt 2>&1 &
+        --train_data_dir=$DATA_DIR > log.txt 2>&1 &
         echo "python train.py \
         --is_distributed=1 \
         --ckpt_path=$SAVE_PATH \
         --eval_each_epoch=$EVAL_EACH_EPOCH \
-        --data_dir=$DATA_DIR > log.txt 2>&1 &"
+        --train_data_dir=$DATA_DIR > log.txt 2>&1 &"
     fi

     if [ $# == 6 ]

@@ -62,7 +63,7 @@ do
         --eval_each_epoch=$EVAL_EACH_EPOCH \
         --ckpt_path=$SAVE_PATH \
         --pretrained=$PATH_CHECKPOINT \
-        --data_dir=$DATA_DIR > log.txt 2>&1 &
+        --train_data_dir=$DATA_DIR > log.txt 2>&1 &
     fi

 cd ../
scripts/train_standalone.sh
@@ -30,12 +30,12 @@ then
     --is_distributed=0 \
     --ckpt_path=$SAVE_CKPT_PATH \
     --eval_each_epoch=$EVAL_EACH_EPOCH \
-    --data_dir=$DATA_DIR > train_log.txt 2>&1 &
+    --train_data_dir=$DATA_DIR > train_log.txt 2>&1 &
     echo " python train.py \
     --is_distributed=0 \
     --ckpt_path=$SAVE_CKPT_PATH \
     --eval_each_epoch=$EVAL_EACH_EPOCH \
-    --data_dir=$DATA_DIR > train_log.txt 2>&1 &"
+    --train_data_dir=$DATA_DIR > train_log.txt 2>&1 &"
 fi
 if [ $# == 5 ]
 then

@@ -43,6 +43,6 @@ then
     --is_distributed=0 \
     --ckpt_path=$SAVE_CKPT_PATH \
     --pretrained=$PATH_CHECKPOINT \
-    --data_dir=$DATA_DIR \
+    --train_data_dir=$DATA_DIR \
     --eval_each_epoch=$EVAL_EACH_EPOCH > train_log.txt 2>&1 &
 fi
src/config.py (deleted)
@@ -1,56 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-network config setting, will be used in train.py and eval.py
-"""
-from easydict import EasyDict as edict
-
-# config for dpn, imagenet-1K
-config = edict()
-
-# model config
-config.image_size = (224, 224)  # input image size
-config.num_classes = 1000  # dataset class number
-config.backbone = 'dpn92'  # backbone network
-config.is_save_on_master = True
-
-# parallel config
-config.num_parallel_workers = 4  # number of workers to read the data
-config.rank = 0  # local rank of distributed
-config.group_size = 1  # group size of distributed
-
-# training config
-config.batch_size = 32  # batch_size
-config.global_step = 0  # start step of learning rate
-config.epoch_size = 180  # epoch_size
-config.loss_scale_num = 1024  # loss scale
-# optimizer config
-config.momentum = 0.9  # momentum (SGD)
-config.weight_decay = 1e-4  # weight_decay (SGD)
-# learning rate config
-config.lr_schedule = 'warmup'  # learning rate schedule
-config.lr_init = 0.01  # init learning rate
-config.lr_max = 0.1  # max learning rate
-config.factor = 0.1  # factor of lr to drop
-config.epoch_number_to_drop = [5, 15]  # learning rate will drop after these epochs
-config.warmup_epochs = 5  # warmup epochs in learning rate schedule
-
-# dataset config
-config.dataset = "imagenet-1K"  # dataset
-config.label_smooth = False  # label_smooth
-config.label_smooth_factor = 0.0  # label_smooth_factor
-
-# parameter save config
-config.keep_checkpoint_max = 3  # only keep the last keep_checkpoint_max checkpoints
src/model_utils/config.py (new file)
@@ -0,0 +1,130 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ====================================================================================
+
+"""Parse arguments"""
+import os
+import ast
+import argparse
+from pprint import pprint, pformat
+import yaml
+
+
+_config_path = '../../default_config.yaml'
+
+
+class Config:
+    """
+    Configuration namespace. Converts a dictionary into members.
+    """
+    def __init__(self, cfg_dict):
+        for k, v in cfg_dict.items():
+            if isinstance(v, (list, tuple)):
+                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
+            else:
+                setattr(self, k, Config(v) if isinstance(v, dict) else v)
+
+    def __str__(self):
+        return pformat(self.__dict__)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path='default_config.yaml'):
+    """
+    Parse command line arguments into the configuration according to the default yaml.
+
+    Args:
+        parser: Parent parser.
+        cfg: Base configuration.
+        helper: Helper description.
+        choices: Allowed choices per key.
+        cfg_path: Path to the default yaml config.
+    """
+    parser = argparse.ArgumentParser(description='[REPLACE THIS at config.py]',
+                                     parents=[parser])
+    helper = {} if helper is None else helper
+    choices = {} if choices is None else choices
+    for item in cfg:
+        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
+            help_description = helper[item] if item in helper else 'Please reference to {}'.format(cfg_path)
+            choice = choices[item] if item in choices else None
+            if isinstance(cfg[item], bool):
+                # booleans are parsed with ast.literal_eval so "True"/"False" work on the CLI
+                parser.add_argument('--' + item, type=ast.literal_eval, default=cfg[item], choices=choice,
+                                    help=help_description)
+            else:
+                parser.add_argument('--' + item, type=type(cfg[item]), default=cfg[item], choices=choice,
+                                    help=help_description)
+    args = parser.parse_args()
+    return args
+
+
+def parse_yaml(yaml_path):
+    """
+    Parse the yaml config file.
+
+    Args:
+        yaml_path: Path to the yaml config.
+    """
+    with open(yaml_path, 'r') as fin:
+        try:
+            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
+            cfgs = [x for x in cfgs]
+            if len(cfgs) == 1:
+                cfg_helper = {}
+                cfg = cfgs[0]
+                cfg_choices = {}
+            elif len(cfgs) == 2:
+                cfg, cfg_helper = cfgs
+                cfg_choices = {}
+            elif len(cfgs) == 3:
+                cfg, cfg_helper, cfg_choices = cfgs
+            else:
+                raise ValueError('At most 3 docs (config, description for help, choices) are supported in the config yaml')
+            print(cfg_helper)
+        except yaml.YAMLError as err:
+            raise ValueError('Failed to parse yaml') from err
+    return cfg, cfg_helper, cfg_choices
+
+
+def merge(args, cfg):
+    """
+    Merge the base config from the yaml file and command line arguments.
+
+    Args:
+        args: Command line arguments.
+        cfg: Base configuration.
+    """
+    args_var = vars(args)
+    for item in args_var:
+        cfg[item] = args_var[item]
+    return cfg
+
+
+def get_config():
+    """
+    Get Config according to the yaml file and cli arguments.
+    """
+    parser = argparse.ArgumentParser(description='default name', add_help=False)
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    parser.add_argument('--config_path', type=str, default=os.path.join(current_dir, _config_path),
+                        help='Config file path')
+    path_args, _ = parser.parse_known_args()
+    default, helper, choices = parse_yaml(path_args.config_path)
+    pprint(default)
+    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
+    final_config = merge(args, default)
+    return Config(final_config)
+
+
+config = get_config()
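
A short usage sketch of the `Config` namespace above, as an illustration that assumes the repository root is on `sys.path`. Because `parse_cli_to_yaml` registers one argparse flag per scalar yaml key, any such key can also be overridden on the command line (e.g. `python train.py --batch_size 64`).

```python
# Illustration only (assumes the repo root on sys.path; note that importing
# the module also runs get_config(), which parses command-line flags as a
# side effect).
from src.model_utils.config import Config

cfg = Config({'lr': {'init': 0.01, 'max': 0.1}, 'image_size': [224, 224]})
print(cfg.lr.init)     # 0.01: the nested dict became a nested Config
print(cfg.image_size)  # [224, 224]: plain lists pass through unchanged
```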
src/model_utils/device_adapter.py (new file)
@@ -0,0 +1,26 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ====================================================================================
+
+"""Device adapter for ModelArts"""
+
+from .config import config
+
+if config.enable_modelarts:
+    from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+else:
+    from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+
+__all__ = [
+    'get_device_id', 'get_device_num', 'get_job_id', 'get_rank_id'
+]
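
The effect, shown here only as an illustration, is that callers stay agnostic about the execution environment:

```python
# Illustration only: train.py and eval.py import these helpers and never
# test config.enable_modelarts themselves; the backend was chosen once,
# above, at import time.
from src.model_utils.device_adapter import get_device_id, get_device_num

print('device', get_device_id(), 'of', get_device_num())  # e.g. 'device 0 of 1' locally
```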
src/model_utils/local_adapter.py (new file)
@@ -0,0 +1,36 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ====================================================================================
+
+"""Local adapter"""
+
+import os
+
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    return 'Local Job'
src/model_utils/moxing_adapter.py (new file)
@@ -0,0 +1,124 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ====================================================================================
+
+"""Moxing adapter for ModelArts"""
+
+import os
+import functools
+from mindspore import context
+from .config import config
+
+
+_global_syn_count = 0
+
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID')
+    job_id = job_id if job_id else "default"
+    return job_id
+
+
+def sync_data(from_path, to_path):
+    """
+    Download data from a remote OBS url to a local directory if the first url is remote and the second is local.
+    Upload data from a local directory to a remote OBS url in the opposite case.
+    """
+    import moxing as mox
+    import time
+    global _global_syn_count
+    sync_lock = '/tmp/copy_sync.lock' + str(_global_syn_count)
+    _global_syn_count += 1
+
+    # Each server contains at most 8 devices
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print('from path: ', from_path)
+        print('to path: ', to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print('===finished data synchronization===')
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            pass
+        print('===save flag===')
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+    print('Finish sync data from {} to {}'.format(from_path, to_path))
+
+
+def moxing_wrapper(pre_process=None, post_process=None):
+    """
+    Moxing wrapper to download the dataset and upload the outputs.
+    """
+    def wrapper(run_func):
+        @functools.wraps(run_func)
+        def wrapped_func(*args, **kwargs):
+            # Download data from data_url
+            if config.enable_modelarts:
+                if config.data_url:
+                    sync_data(config.data_url, config.data_path)
+                    print('Dataset downloaded: ', os.listdir(config.data_path))
+                if config.checkpoint_url:
+                    if not os.path.exists(config.load_path):
+                        # os.makedirs(config.load_path)
+                        print('=' * 20 + 'makedirs')
+                        if os.path.isdir(config.load_path):
+                            print('=' * 20 + 'makedirs success')
+                        else:
+                            print('=' * 20 + 'makedirs fail')
+                    sync_data(config.checkpoint_url, config.load_path)
+                    print('Preload downloaded: ', os.listdir(config.load_path))
+                if config.train_url:
+                    sync_data(config.train_url, config.output_path)
+                    print('Workspace downloaded: ', os.listdir(config.output_path))
+
+                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                config.device_num = get_device_num()
+                config.device_id = get_device_id()
+                if not os.path.exists(config.output_path):
+                    os.makedirs(config.output_path)
+
+                if pre_process:
+                    pre_process()
+
+            run_func(*args, **kwargs)
+
+            # Upload data to train_url
+            if config.enable_modelarts:
+                if post_process:
+                    post_process()
+
+                if config.train_url:
+                    print('Start to copy output directory')
+                    sync_data(config.output_path, config.train_url)
+        return wrapped_func
+    return wrapper
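
As a hedged sketch of the wrapper's control flow (the hook name below is hypothetical; `train.py` later in this commit does the equivalent with `modelarts_pre_process`):

```python
# Hedged sketch, not repository code. On ModelArts the decorator first syncs
# config.data_url down to config.data_path, runs the optional pre_process
# hook, calls the wrapped function, and finally uploads config.output_path
# back to config.train_url; off ModelArts it simply calls the function.
from src.model_utils.moxing_adapter import moxing_wrapper

def unpack_dataset():  # hypothetical pre-processing hook
    print('unpack archives under config.data_path here')

@moxing_wrapper(pre_process=unpack_dataset)
def main():
    print('training code sees only local /cache paths')

if __name__ == '__main__':
    main()
```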
train.py
@@ -14,8 +14,7 @@
 # ============================================================================
 """DPN model train with MindSpore"""
 import os
-import argparse
-
+from ast import literal_eval
 from mindspore import context
 from mindspore import Tensor
 from mindspore.nn import SGD

@@ -24,145 +23,112 @@ from mindspore.train.model import Model
 from mindspore.context import ParallelMode
 from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
-from mindspore.communication.management import init, get_group_size, get_rank
+from mindspore.communication.management import init, get_group_size
 from mindspore.common import set_seed
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

 from src.imagenet_dataset import classification_dataset
 from src.dpn import dpns
-from src.config import config
 from src.lr_scheduler import get_lr_drop, get_lr_warmup
 from src.crossentropy import CrossEntropy
 from src.callbacks import SaveCallback
+from src.model_utils.config import config
+from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num

-device_id = int(os.getenv('DEVICE_ID'))
-
 set_seed(1)


-def parse_args():
-    """parameters"""
-    parser = argparse.ArgumentParser('dpn training')
-
-    # dataset related
-    parser.add_argument('--data_dir', type=str, default='', help='ImageNet data dir')
-    # network related
-    parser.add_argument('--pretrained', default='', type=str, help='ckpt path to load')
-    # distributed related
-    parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
-    parser.add_argument('--ckpt_path', type=str, default='', help='ckpt path to save')
-    parser.add_argument('--eval_each_epoch', type=int, default=0, help='evaluate on each epoch')
-    args, _ = parser.parse_known_args()
-    args.image_size = config.image_size
-    args.num_classes = config.num_classes
-    args.lr_init = config.lr_init
-    args.lr_max = config.lr_max
-    args.factor = config.factor
-    args.global_step = config.global_step
-    args.epoch_number_to_drop = config.epoch_number_to_drop
-    args.epoch_size = config.epoch_size
-    args.warmup_epochs = config.warmup_epochs
-    args.weight_decay = config.weight_decay
-    args.momentum = config.momentum
-    args.batch_size = config.batch_size
-    args.num_parallel_workers = config.num_parallel_workers
-    args.backbone = config.backbone
-    args.loss_scale_num = config.loss_scale_num
-    args.is_save_on_master = config.is_save_on_master
-    args.rank = config.rank
-    args.group_size = config.group_size
-    args.dataset = config.dataset
-    args.label_smooth = config.label_smooth
-    args.label_smooth_factor = config.label_smooth_factor
-    args.keep_checkpoint_max = config.keep_checkpoint_max
-    args.lr_schedule = config.lr_schedule
-    return args
+def modelarts_pre_process():
+    pass


-def dpn_train(args):
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def dpn_train():
     # init context
+    device_id = get_device_id()
     context.set_context(mode=context.GRAPH_MODE,
-                        device_target="Ascend", save_graphs=False, device_id=device_id)
+                        device_target=config.device_target, save_graphs=False, device_id=device_id)
     # init distributed
-    if args.is_distributed:
+    if config.is_distributed:
         init()
-        args.rank = get_rank()
-        args.group_size = get_group_size()
-        context.set_auto_parallel_context(device_num=args.group_size, parallel_mode=ParallelMode.DATA_PARALLEL,
+        config.rank = get_rank_id()
+        config.group_size = get_group_size()
+        config.device_num = get_device_num()
+        context.set_auto_parallel_context(device_num=config.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)

     # select whether only the master rank saves ckpt or all ranks save; compatible with model parallel
-    args.rank_save_ckpt_flag = 0
-    if args.is_save_on_master:
-        if args.rank == 0:
-            args.rank_save_ckpt_flag = 1
+    config.rank_save_ckpt_flag = 0
+    if config.is_save_on_master:
+        if config.rank == 0:
+            config.rank_save_ckpt_flag = 1
     else:
-        args.rank_save_ckpt_flag = 1
+        config.rank_save_ckpt_flag = 1
     # create dataset
-    args.train_dir = os.path.join(args.data_dir, 'train')
-    args.eval_dir = os.path.join(args.data_dir, 'val')
-    train_dataset = classification_dataset(args.train_dir,
-                                           image_size=args.image_size,
-                                           per_batch_size=args.batch_size,
+    train_dataset = classification_dataset(config.train_data_dir,
+                                           image_size=config.image_size,
+                                           per_batch_size=config.batch_size,
                                            max_epoch=1,
-                                           num_parallel_workers=args.num_parallel_workers,
+                                           num_parallel_workers=config.num_parallel_workers,
                                            shuffle=True,
-                                           rank=args.rank,
-                                           group_size=args.group_size)
-    if args.eval_each_epoch:
+                                           rank=config.rank,
+                                           group_size=config.group_size)
+    if config.eval_each_epoch:
         print("create eval_dataset")
-        eval_dataset = classification_dataset(args.eval_dir,
-                                              image_size=args.image_size,
-                                              per_batch_size=args.batch_size,
+        eval_dataset = classification_dataset(config.eval_data_dir,
+                                              image_size=config.image_size,
+                                              per_batch_size=config.batch_size,
                                               max_epoch=1,
-                                              num_parallel_workers=args.num_parallel_workers,
+                                              num_parallel_workers=config.num_parallel_workers,
                                               shuffle=False,
-                                              rank=args.rank,
-                                              group_size=args.group_size,
+                                              rank=config.rank,
+                                              group_size=config.group_size,
                                               mode='eval')
     train_step_size = train_dataset.get_dataset_size()

     # choose net
-    net = dpns[args.backbone](num_classes=args.num_classes)
+    net = dpns[config.backbone](num_classes=config.num_classes)

     # load checkpoint
-    if os.path.isfile(args.pretrained):
+    if os.path.isfile(config.pretrained):
         print("load ckpt")
-        load_param_into_net(net, load_checkpoint(args.pretrained))
+        load_param_into_net(net, load_checkpoint(config.pretrained))
     # learning rate schedule
-    if args.lr_schedule == 'drop':
+    if config.lr_schedule == 'drop':
         print("lr_schedule:drop")
-        lr = Tensor(get_lr_drop(global_step=args.global_step,
-                                total_epochs=args.epoch_size,
+        lr = Tensor(get_lr_drop(global_step=config.global_step,
+                                total_epochs=config.epoch_size,
                                 steps_per_epoch=train_step_size,
-                                lr_init=args.lr_init,
-                                factor=args.factor))
-    elif args.lr_schedule == 'warmup':
+                                lr_init=config.lr_init,
+                                factor=config.factor))
+    elif config.lr_schedule == 'warmup':
         print("lr_schedule:warmup")
-        lr = Tensor(get_lr_warmup(global_step=args.global_step,
-                                  total_epochs=args.epoch_size,
+        lr = Tensor(get_lr_warmup(global_step=config.global_step,
+                                  total_epochs=config.epoch_size,
                                   steps_per_epoch=train_step_size,
-                                  lr_init=args.lr_init,
-                                  lr_max=args.lr_max,
-                                  warmup_epochs=args.warmup_epochs))
+                                  lr_init=config.lr_init,
+                                  lr_max=config.lr_max,
+                                  warmup_epochs=config.warmup_epochs))

     # optimizer
+    # weight_decay comes out of the yaml as the string "1e-4"; literal_eval turns it into a float
+    config.weight_decay = literal_eval(config.weight_decay)
     opt = SGD(net.trainable_params(),
               lr,
-              momentum=args.momentum,
-              weight_decay=args.weight_decay,
-              loss_scale=args.loss_scale_num)
+              momentum=config.momentum,
+              weight_decay=config.weight_decay,
+              loss_scale=config.loss_scale_num)
     # loss scale
-    loss_scale = FixedLossScaleManager(args.loss_scale_num, False)
+    loss_scale = FixedLossScaleManager(config.loss_scale_num, False)
     # loss function
-    if args.dataset == "imagenet-1K":
+    if config.dataset == "imagenet-1K":
         print("Use SoftmaxCrossEntropyWithLogits")
         loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     else:
-        if not args.label_smooth:
-            args.label_smooth_factor = 0.0
+        if not config.label_smooth:
+            config.label_smooth_factor = 0.0
         print("Use Label_smooth CrossEntropy")
-        loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
+        loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.num_classes)
     # create model
     model = Model(net, amp_level="O2",
                   keep_batchnorm_fp32=False,

@@ -175,19 +141,19 @@ def dpn_train(args):
     loss_cb = LossMonitor()
     time_cb = TimeMonitor(data_size=train_step_size)
     cb = [loss_cb, time_cb]
-    if args.rank_save_ckpt_flag:
-        if args.eval_each_epoch:
-            save_cb = SaveCallback(model, eval_dataset, args.ckpt_path)
+    if config.rank_save_ckpt_flag:
+        if config.eval_each_epoch:
+            save_cb = SaveCallback(model, eval_dataset, config.ckpt_path)
             cb += [save_cb]
         else:
             config_ck = CheckpointConfig(save_checkpoint_steps=train_step_size,
-                                         keep_checkpoint_max=args.keep_checkpoint_max)
-            ckpoint_cb = ModelCheckpoint(prefix="dpn", directory=args.ckpt_path, config=config_ck)
+                                         keep_checkpoint_max=config.keep_checkpoint_max)
+            ckpoint_cb = ModelCheckpoint(prefix="dpn", directory=config.ckpt_path, config=config_ck)
             cb.append(ckpoint_cb)
     # train model
-    model.train(args.epoch_size, train_dataset, callbacks=cb)
+    model.train(config.epoch_size, train_dataset, callbacks=cb)


 if __name__ == '__main__':
-    dpn_train(parse_args())
+    dpn_train()
     print('DPN training success!')
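
The `drop` and `warmup` branches above delegate to `get_lr_drop` and `get_lr_warmup` in `src/lr_scheduler.py`, which this diff does not show. A hedged illustration (a sketch, not the repository's scheduler) of how the yaml keys `lr_init`, `lr_max`, `warmup_epochs`, `factor`, and `epoch_number_to_drop` interact:

```python
# Illustrative sketch only; not src/lr_scheduler.py. 'warmup' ramps linearly
# from lr_init to lr_max over warmup_epochs; 'drop' multiplies lr_init by
# factor at every epoch listed in epoch_number_to_drop.
import numpy as np

def sketch_lr(schedule, steps_per_epoch, total_epochs, lr_init=0.01, lr_max=0.1,
              warmup_epochs=5, factor=0.1, epoch_number_to_drop=(5, 15)):
    lr = []
    for step in range(total_epochs * steps_per_epoch):
        epoch = step // steps_per_epoch
        if schedule == 'warmup':
            warmup_steps = warmup_epochs * steps_per_epoch
            lr.append(lr_init + (lr_max - lr_init) * min(step, warmup_steps) / warmup_steps)
        else:  # 'drop'
            drops = sum(1 for e in epoch_number_to_drop if epoch >= e)
            lr.append(lr_init * factor ** drops)
    return np.array(lr, np.float32)

print(sketch_lr('warmup', steps_per_epoch=2, total_epochs=8)[:6])  # linear ramp
```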