forked from mindspore-Ecosystem/mindspore
!17951 modify model_zoo squeezenet
From: @Somnus2020 Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
77e562db4f
|
@ -100,37 +100,6 @@ After installing MindSpore via the official website, you can start training and
|
||||||
sh scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]
|
sh scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||||
```
|
```
|
||||||
|
|
||||||
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training and evaluation as follows:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# run distributed training on modelarts example
|
|
||||||
# (1) First, Perform a or b.
|
|
||||||
# a. Set "enable_modelarts=True" on yaml file.
|
|
||||||
# Set other parameters on yaml file you need.
|
|
||||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
|
||||||
# Add other parameters on the website UI interface.
|
|
||||||
# (2) Set the Dataset directory in config file.
|
|
||||||
# (3) Set the code directory to "/path/squeezenet" on the website UI interface.
|
|
||||||
# (4) Set the startup file to "train.py" on the website UI interface.
|
|
||||||
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
|
||||||
# (6) Create your job.
|
|
||||||
|
|
||||||
# run evaluation on modelarts example
|
|
||||||
# (1) Copy or upload your trained model to S3 bucket.
|
|
||||||
# (2) Perform a or b.
|
|
||||||
# a. Set "enable_modelarts=True" on yaml file.
|
|
||||||
# Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on yaml file.
|
|
||||||
# Set "checkpoint_url=/The path of checkpoint in S3/" on yaml file.
|
|
||||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
|
||||||
# Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
|
|
||||||
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
|
|
||||||
# (3) Set the Dataset directory in config file.
|
|
||||||
# (4) Set the code directory to "/path/squeezenet" on the website UI interface.
|
|
||||||
# (5) Set the startup file to "eval.py" on the website UI interface.
|
|
||||||
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
|
||||||
# (7) Create your job.
|
|
||||||
```
|
|
||||||
|
|
||||||
# [Script Description](#contents)
|
# [Script Description](#contents)
|
||||||
|
|
||||||
## [Script and Sample Code](#contents)
|
## [Script and Sample Code](#contents)
|
||||||
|
@ -147,19 +116,11 @@ After installing MindSpore via the official website, you can start training and
|
||||||
├── run_eval.sh # launch ascend evaluation
|
├── run_eval.sh # launch ascend evaluation
|
||||||
└── run_eval_gpu.sh # launch gpu evaluation
|
└── run_eval_gpu.sh # launch gpu evaluation
|
||||||
├── src
|
├── src
|
||||||
|
├── config.py # parameter configuration
|
||||||
├── dataset.py # data preprocessing
|
├── dataset.py # data preprocessing
|
||||||
├── CrossEntropySmooth.py # loss definition for ImageNet dataset
|
├── CrossEntropySmooth.py # loss definition for ImageNet dataset
|
||||||
├── lr_generator.py # generate learning rate for each step
|
├── lr_generator.py # generate learning rate for each step
|
||||||
└── squeezenet.py # squeezenet architecture, including squeezenet and squeezenet_residual
|
└── squeezenet.py # squeezenet architecture, including squeezenet and squeezenet_residual
|
||||||
├── model_utils
|
|
||||||
│ ├── device_adapter.py # device adapter
|
|
||||||
│ ├── local_adapter.py # local adapter
|
|
||||||
│ ├── moxing_adapter.py # moxing adapter
|
|
||||||
│ ├── config.py # parameter analysis
|
|
||||||
├── squeezenet_cifar10_config.yaml # parameter configuration
|
|
||||||
├── squeezenet_imagenet_config.yaml # parameter configuration
|
|
||||||
├── squeezenet_residual_cifar10_config.yaml # parameter configuration
|
|
||||||
├── squeezenet_residual_imagenet_config.yaml # parameter configuration
|
|
||||||
├── train.py # train net
|
├── train.py # train net
|
||||||
├── eval.py # eval net
|
├── eval.py # eval net
|
||||||
└── export.py # export checkpoint files into geir/onnx
|
└── export.py # export checkpoint files into geir/onnx
|
||||||
|
|
|
@ -14,34 +14,44 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""eval squeezenet."""
|
"""eval squeezenet."""
|
||||||
import os
|
import os
|
||||||
|
import argparse
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.common import set_seed
|
from mindspore.common import set_seed
|
||||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from model_utils.config import config
|
|
||||||
from model_utils.moxing_adapter import moxing_wrapper
|
|
||||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
|
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
|
||||||
|
help='Model.')
|
||||||
|
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
|
||||||
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||||
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
if config.net_name == "squeezenet":
|
if args_opt.net == "squeezenet":
|
||||||
from src.squeezenet import SqueezeNet as squeezenet
|
from src.squeezenet import SqueezeNet as squeezenet
|
||||||
if config.dataset == "cifar10":
|
if args_opt.dataset == "cifar10":
|
||||||
|
from src.config import config1 as config
|
||||||
from src.dataset import create_dataset_cifar as create_dataset
|
from src.dataset import create_dataset_cifar as create_dataset
|
||||||
else:
|
else:
|
||||||
|
from src.config import config2 as config
|
||||||
from src.dataset import create_dataset_imagenet as create_dataset
|
from src.dataset import create_dataset_imagenet as create_dataset
|
||||||
else:
|
else:
|
||||||
from src.squeezenet import SqueezeNet_Residual as squeezenet
|
from src.squeezenet import SqueezeNet_Residual as squeezenet
|
||||||
if config.dataset == "cifar10":
|
if args_opt.dataset == "cifar10":
|
||||||
|
from src.config import config3 as config
|
||||||
from src.dataset import create_dataset_cifar as create_dataset
|
from src.dataset import create_dataset_cifar as create_dataset
|
||||||
else:
|
else:
|
||||||
|
from src.config import config4 as config
|
||||||
from src.dataset import create_dataset_imagenet as create_dataset
|
from src.dataset import create_dataset_imagenet as create_dataset
|
||||||
|
|
||||||
@moxing_wrapper()
|
if __name__ == '__main__':
|
||||||
def eval_net():
|
target = args_opt.device_target
|
||||||
"""eval net """
|
|
||||||
target = config.device_target
|
|
||||||
|
|
||||||
# init context
|
# init context
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
@ -50,21 +60,22 @@ def eval_net():
|
||||||
device_id=device_id)
|
device_id=device_id)
|
||||||
|
|
||||||
# create dataset
|
# create dataset
|
||||||
dataset = create_dataset(dataset_path=config.data_path,
|
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||||
do_train=False,
|
do_train=False,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
target=target)
|
target=target)
|
||||||
|
step_size = dataset.get_dataset_size()
|
||||||
|
|
||||||
# define net
|
# define net
|
||||||
net = squeezenet(num_classes=config.class_num)
|
net = squeezenet(num_classes=config.class_num)
|
||||||
|
|
||||||
# load checkpoint
|
# load checkpoint
|
||||||
param_dict = load_checkpoint(config.checkpoint_file_path)
|
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
net.set_train(False)
|
net.set_train(False)
|
||||||
|
|
||||||
# define loss
|
# define loss
|
||||||
if config.dataset == "imagenet":
|
if args_opt.dataset == "imagenet":
|
||||||
if not config.use_label_smooth:
|
if not config.use_label_smooth:
|
||||||
config.label_smooth_factor = 0.0
|
config.label_smooth_factor = 0.0
|
||||||
loss = CrossEntropySmooth(sparse=True,
|
loss = CrossEntropySmooth(sparse=True,
|
||||||
|
@ -81,7 +92,4 @@ def eval_net():
|
||||||
|
|
||||||
# eval model
|
# eval model
|
||||||
res = model.eval(dataset)
|
res = model.eval(dataset)
|
||||||
print("result:", res, "ckpt=", config.checkpoint_file_path)
|
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
eval_net()
|
|
||||||
|
|
|
@ -17,29 +17,36 @@
|
||||||
python export.py --net squeezenet --dataset cifar10 --checkpoint_path squeezenet_cifar10-120_1562.ckpt
|
python export.py --net squeezenet --dataset cifar10 --checkpoint_path squeezenet_cifar10-120_1562.ckpt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||||
from model_utils.config import config
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if config.net_name == "squeezenet":
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
|
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
|
||||||
|
help='Model.')
|
||||||
|
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
|
||||||
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
if args_opt.net == "squeezenet":
|
||||||
from src.squeezenet import SqueezeNet as squeezenet
|
from src.squeezenet import SqueezeNet as squeezenet
|
||||||
else:
|
else:
|
||||||
from src.squeezenet import SqueezeNet_Residual as squeezenet
|
from src.squeezenet import SqueezeNet_Residual as squeezenet
|
||||||
if config.dataset == "cifar10":
|
if args_opt.dataset == "cifar10":
|
||||||
num_classes = 10
|
num_classes = 10
|
||||||
else:
|
else:
|
||||||
num_classes = 1000
|
num_classes = 1000
|
||||||
|
|
||||||
onnx_filename = config.net_name + '_' + config.dataset
|
onnx_filename = args_opt.net + '_' + args_opt.dataset
|
||||||
air_filename = config.net_name + '_' + config.dataset
|
air_filename = args_opt.net + '_' + args_opt.dataset
|
||||||
|
|
||||||
net = squeezenet(num_classes=num_classes)
|
net = squeezenet(num_classes=num_classes)
|
||||||
|
|
||||||
assert config.checkpoint_file_path is not None, "checkpoint_file_path is None."
|
assert args_opt.checkpoint_path is not None, "checkpoint_path is None."
|
||||||
|
|
||||||
param_dict = load_checkpoint(config.checkpoint_file_path)
|
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
input_arr = Tensor(np.zeros([1, 3, 227, 227], np.float32))
|
input_arr = Tensor(np.zeros([1, 3, 227, 227], np.float32))
|
||||||
|
|
|
@ -1,124 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
"""Parse arguments"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import ast
|
|
||||||
import argparse
|
|
||||||
from pprint import pformat
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
_config_path = "./squeezenet_cifar10_config.yaml"
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""
|
|
||||||
Configuration namespace. Convert dictionary to members.
|
|
||||||
"""
|
|
||||||
def __init__(self, cfg_dict):
|
|
||||||
for k, v in cfg_dict.items():
|
|
||||||
if isinstance(v, (list, tuple)):
|
|
||||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
|
||||||
else:
|
|
||||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return pformat(self.__dict__)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self.__str__()
|
|
||||||
|
|
||||||
|
|
||||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="squeezenet_cifar10_config.yaml"):
|
|
||||||
"""
|
|
||||||
Parse command line arguments to the configuration according to the default yaml.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
parser: Parent parser.
|
|
||||||
cfg: Base configuration.
|
|
||||||
helper: Helper description.
|
|
||||||
cfg_path: Path to the default yaml config.
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
|
||||||
parents=[parser])
|
|
||||||
helper = {} if helper is None else helper
|
|
||||||
choices = {} if choices is None else choices
|
|
||||||
for item in cfg:
|
|
||||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
|
||||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
|
||||||
choice = choices[item] if item in choices else None
|
|
||||||
if isinstance(cfg[item], bool):
|
|
||||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
|
||||||
help=help_description)
|
|
||||||
else:
|
|
||||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
|
||||||
help=help_description)
|
|
||||||
args = parser.parse_args()
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def parse_yaml(yaml_path):
|
|
||||||
"""
|
|
||||||
Parse the yaml config file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
yaml_path: Path to the yaml config.
|
|
||||||
"""
|
|
||||||
with open(yaml_path, 'r') as fin:
|
|
||||||
try:
|
|
||||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
|
||||||
cfgs = [x for x in cfgs]
|
|
||||||
if len(cfgs) == 1:
|
|
||||||
cfg_helper = {}
|
|
||||||
cfg = cfgs[0]
|
|
||||||
elif len(cfgs) == 2:
|
|
||||||
cfg, cfg_helper = cfgs
|
|
||||||
else:
|
|
||||||
raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml")
|
|
||||||
print(cfg_helper)
|
|
||||||
except:
|
|
||||||
raise ValueError("Failed to parse yaml")
|
|
||||||
return cfg, cfg_helper
|
|
||||||
|
|
||||||
|
|
||||||
def merge(args, cfg):
|
|
||||||
"""
|
|
||||||
Merge the base config from yaml file and command line arguments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Command line arguments.
|
|
||||||
cfg: Base configuration.
|
|
||||||
"""
|
|
||||||
args_var = vars(args)
|
|
||||||
for item in args_var:
|
|
||||||
cfg[item] = args_var[item]
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def get_config():
|
|
||||||
"""
|
|
||||||
Get Config according to the yaml file and cli arguments.
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, \
|
|
||||||
"../squeezenet_cifar10_config.yaml"), help="Config file path")
|
|
||||||
path_args, _ = parser.parse_known_args()
|
|
||||||
default, helper = parse_yaml(path_args.config_path)
|
|
||||||
args = parse_cli_to_yaml(parser, default, helper, path_args.config_path)
|
|
||||||
final_config = merge(args, default)
|
|
||||||
return Config(final_config)
|
|
||||||
|
|
||||||
config = get_config()
|
|
|
@ -1,27 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
"""Device adapter for ModelArts"""
|
|
||||||
|
|
||||||
from model_utils.config import config
|
|
||||||
|
|
||||||
if config.enable_modelarts:
|
|
||||||
from model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
|
||||||
else:
|
|
||||||
from model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
|
|
||||||
]
|
|
|
@ -1,36 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
"""Local adapter"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
def get_device_id():
|
|
||||||
device_id = os.getenv('DEVICE_ID', '0')
|
|
||||||
return int(device_id)
|
|
||||||
|
|
||||||
|
|
||||||
def get_device_num():
|
|
||||||
device_num = os.getenv('RANK_SIZE', '1')
|
|
||||||
return int(device_num)
|
|
||||||
|
|
||||||
|
|
||||||
def get_rank_id():
|
|
||||||
global_rank_id = os.getenv('RANK_ID', '0')
|
|
||||||
return int(global_rank_id)
|
|
||||||
|
|
||||||
|
|
||||||
def get_job_id():
|
|
||||||
return "Local Job"
|
|
|
@ -1,115 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
"""Moxing adapter for ModelArts"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import functools
|
|
||||||
from mindspore import context
|
|
||||||
from model_utils.config import config
|
|
||||||
|
|
||||||
_global_sync_count = 0
|
|
||||||
|
|
||||||
def get_device_id():
|
|
||||||
device_id = os.getenv('DEVICE_ID', '0')
|
|
||||||
return int(device_id)
|
|
||||||
|
|
||||||
|
|
||||||
def get_device_num():
|
|
||||||
device_num = os.getenv('RANK_SIZE', '1')
|
|
||||||
return int(device_num)
|
|
||||||
|
|
||||||
|
|
||||||
def get_rank_id():
|
|
||||||
global_rank_id = os.getenv('RANK_ID', '0')
|
|
||||||
return int(global_rank_id)
|
|
||||||
|
|
||||||
|
|
||||||
def get_job_id():
|
|
||||||
job_id = os.getenv('JOB_ID')
|
|
||||||
job_id = job_id if job_id != "" else "default"
|
|
||||||
return job_id
|
|
||||||
|
|
||||||
def sync_data(from_path, to_path):
|
|
||||||
"""
|
|
||||||
Download data from remote obs to local directory if the first url is remote url and the second one is local path
|
|
||||||
Upload data from local directory to remote obs in contrast.
|
|
||||||
"""
|
|
||||||
import moxing as mox
|
|
||||||
import time
|
|
||||||
global _global_sync_count
|
|
||||||
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
|
|
||||||
_global_sync_count += 1
|
|
||||||
|
|
||||||
# Each server contains 8 devices as most.
|
|
||||||
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
|
|
||||||
print("from path: ", from_path)
|
|
||||||
print("to path: ", to_path)
|
|
||||||
mox.file.copy_parallel(from_path, to_path)
|
|
||||||
print("===finish data synchronization===")
|
|
||||||
try:
|
|
||||||
os.mknod(sync_lock)
|
|
||||||
except IOError:
|
|
||||||
pass
|
|
||||||
print("===save flag===")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if os.path.exists(sync_lock):
|
|
||||||
break
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
print("Finish sync data from {} to {}.".format(from_path, to_path))
|
|
||||||
|
|
||||||
|
|
||||||
def moxing_wrapper(pre_process=None, post_process=None):
|
|
||||||
"""
|
|
||||||
Moxing wrapper to download dataset and upload outputs.
|
|
||||||
"""
|
|
||||||
def wrapper(run_func):
|
|
||||||
@functools.wraps(run_func)
|
|
||||||
def wrapped_func(*args, **kwargs):
|
|
||||||
# Download data from data_url
|
|
||||||
if config.enable_modelarts:
|
|
||||||
if config.data_url:
|
|
||||||
sync_data(config.data_url, config.data_path)
|
|
||||||
print("Dataset downloaded: ", os.listdir(config.data_path))
|
|
||||||
if config.checkpoint_url:
|
|
||||||
sync_data(config.checkpoint_url, config.load_path)
|
|
||||||
print("Preload downloaded: ", os.listdir(config.load_path))
|
|
||||||
if config.train_url:
|
|
||||||
sync_data(config.train_url, config.output_path)
|
|
||||||
print("Workspace downloaded: ", os.listdir(config.output_path))
|
|
||||||
|
|
||||||
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
|
|
||||||
config.device_num = get_device_num()
|
|
||||||
config.device_id = get_device_id()
|
|
||||||
if not os.path.exists(config.output_path):
|
|
||||||
os.makedirs(config.output_path)
|
|
||||||
|
|
||||||
if pre_process:
|
|
||||||
pre_process()
|
|
||||||
|
|
||||||
run_func(*args, **kwargs)
|
|
||||||
|
|
||||||
# Upload data to train_url
|
|
||||||
if config.enable_modelarts:
|
|
||||||
if post_process:
|
|
||||||
post_process()
|
|
||||||
|
|
||||||
if config.train_url:
|
|
||||||
print("Start to copy output directory")
|
|
||||||
sync_data(config.output_path, config.train_url)
|
|
||||||
return wrapped_func
|
|
||||||
return wrapper
|
|
|
@ -74,22 +74,6 @@ export RANK_TABLE_FILE=$PATH1
|
||||||
export SERVER_ID=0
|
export SERVER_ID=0
|
||||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||||
|
|
||||||
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
|
|
||||||
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
|
|
||||||
else
|
|
||||||
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
for((i=0; i<${DEVICE_NUM}; i++))
|
for((i=0; i<${DEVICE_NUM}; i++))
|
||||||
do
|
do
|
||||||
export DEVICE_ID=${i}
|
export DEVICE_ID=${i}
|
||||||
|
@ -98,21 +82,17 @@ do
|
||||||
mkdir ./train_parallel$i
|
mkdir ./train_parallel$i
|
||||||
cp ./train.py ./train_parallel$i
|
cp ./train.py ./train_parallel$i
|
||||||
cp -r ./src ./train_parallel$i
|
cp -r ./src ./train_parallel$i
|
||||||
cp -r ./model_utils ./train_parallel$i
|
|
||||||
cp -r ./*.yaml ./train_parallel$i
|
|
||||||
cd ./train_parallel$i || exit
|
cd ./train_parallel$i || exit
|
||||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||||
env > env.log
|
env > env.log
|
||||||
if [ $# == 4 ]
|
if [ $# == 4 ]
|
||||||
then
|
then
|
||||||
python train.py --net_name=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --data_path=$PATH2 \
|
python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
|
||||||
--config_path=$CONFIG_FILE --output_path './output' &> log &
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $# == 5 ]
|
if [ $# == 5 ]
|
||||||
then
|
then
|
||||||
python train.py --net_name=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --data_path=$PATH2 \
|
python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
|
||||||
--pre_trained=$PATH3 --config_path=$CONFIG_FILE --output_path './output' &> log &
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
|
@ -64,42 +64,22 @@ ulimit -u unlimited
|
||||||
export DEVICE_NUM=8
|
export DEVICE_NUM=8
|
||||||
export RANK_SIZE=8
|
export RANK_SIZE=8
|
||||||
|
|
||||||
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
|
|
||||||
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
|
|
||||||
else
|
|
||||||
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
rm -rf ./train_parallel
|
rm -rf ./train_parallel
|
||||||
mkdir ./train_parallel
|
mkdir ./train_parallel
|
||||||
cp ./train.py ./train_parallel
|
cp ./train.py ./train_parallel
|
||||||
cp -r ./src ./train_parallel
|
cp -r ./src ./train_parallel
|
||||||
cp -r ./model_utils ./train_parallel
|
|
||||||
cp -r ./*.yaml ./train_parallel
|
|
||||||
cd ./train_parallel || exit
|
cd ./train_parallel || exit
|
||||||
|
|
||||||
if [ $# == 3 ]
|
if [ $# == 3 ]
|
||||||
then
|
then
|
||||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||||
python train.py --net_name=$1 --dataset=$2 --run_distribute=True \
|
python train.py --net=$1 --dataset=$2 --run_distribute=True \
|
||||||
--device_num=$DEVICE_NUM --device_target="GPU" --data_path=$PATH1 \
|
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log &
|
||||||
--config_path=$CONFIG_FILE --output_path './output' &> log &
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $# == 4 ]
|
if [ $# == 4 ]
|
||||||
then
|
then
|
||||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||||
python train.py --net_name=$1 --dataset=$2 --run_distribute=True \
|
python train.py --net=$1 --dataset=$2 --run_distribute=True \
|
||||||
--device_num=$DEVICE_NUM --device_target="GPU" --data_path=$PATH1 --pre_trained=$PATH2 \
|
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
|
||||||
--config_path=$CONFIG_FILE --output_path './output' &> log &
|
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -62,22 +62,6 @@ export DEVICE_ID=$3
|
||||||
export RANK_SIZE=$DEVICE_NUM
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
export RANK_ID=0
|
export RANK_ID=0
|
||||||
|
|
||||||
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
|
|
||||||
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
|
|
||||||
else
|
|
||||||
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -d "eval" ];
|
if [ -d "eval" ];
|
||||||
then
|
then
|
||||||
rm -rf ./eval
|
rm -rf ./eval
|
||||||
|
@ -85,11 +69,8 @@ fi
|
||||||
mkdir ./eval
|
mkdir ./eval
|
||||||
cp ./eval.py ./eval
|
cp ./eval.py ./eval
|
||||||
cp -r ./src ./eval
|
cp -r ./src ./eval
|
||||||
cp -r ./model_utils ./eval
|
|
||||||
cp -r ./*.yaml ./eval
|
|
||||||
cd ./eval || exit
|
cd ./eval || exit
|
||||||
env > env.log
|
env > env.log
|
||||||
echo "start evaluation for device $DEVICE_ID"
|
echo "start evaluation for device $DEVICE_ID"
|
||||||
python eval.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --checkpoint_file_path=$PATH2 \
|
python eval.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
|
||||||
--config_path=$CONFIG_FILE --output_path './output' &> log &
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
|
@ -62,22 +62,6 @@ export DEVICE_ID=$3
|
||||||
export RANK_SIZE=$DEVICE_NUM
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
export RANK_ID=0
|
export RANK_ID=0
|
||||||
|
|
||||||
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
|
|
||||||
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
|
|
||||||
else
|
|
||||||
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -d "eval" ];
|
if [ -d "eval" ];
|
||||||
then
|
then
|
||||||
rm -rf ./eval
|
rm -rf ./eval
|
||||||
|
@ -85,11 +69,8 @@ fi
|
||||||
mkdir ./eval
|
mkdir ./eval
|
||||||
cp ./eval.py ./eval
|
cp ./eval.py ./eval
|
||||||
cp -r ./src ./eval
|
cp -r ./src ./eval
|
||||||
cp -r ./model_utils ./eval
|
|
||||||
cp -r ./*.yaml ./eval
|
|
||||||
cd ./eval || exit
|
cd ./eval || exit
|
||||||
env > env.log
|
env > env.log
|
||||||
echo "start evaluation for device $DEVICE_ID"
|
echo "start evaluation for device $DEVICE_ID"
|
||||||
python eval.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --checkpoint_file_path=$PATH2 --device_target="GPU" \
|
python eval.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
|
||||||
--config_path=$CONFIG_FILE --output_path './output' &> log &
|
|
||||||
cd ..
|
cd ..
|
||||||
|
|
|
@ -65,22 +65,6 @@ export DEVICE_ID=$3
|
||||||
export RANK_ID=0
|
export RANK_ID=0
|
||||||
export RANK_SIZE=1
|
export RANK_SIZE=1
|
||||||
|
|
||||||
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
|
|
||||||
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
|
|
||||||
else
|
|
||||||
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -d "train" ];
|
if [ -d "train" ];
|
||||||
then
|
then
|
||||||
rm -rf ./train
|
rm -rf ./train
|
||||||
|
@ -88,19 +72,16 @@ fi
|
||||||
mkdir ./train
|
mkdir ./train
|
||||||
cp ./train.py ./train
|
cp ./train.py ./train
|
||||||
cp -r ./src ./train
|
cp -r ./src ./train
|
||||||
cp -r ./model_utils ./train
|
|
||||||
cp -r ./*.yaml ./train
|
|
||||||
cd ./train || exit
|
cd ./train || exit
|
||||||
echo "start training for device $DEVICE_ID"
|
echo "start training for device $DEVICE_ID"
|
||||||
env > env.log
|
env > env.log
|
||||||
if [ $# == 4 ]
|
if [ $# == 4 ]
|
||||||
then
|
then
|
||||||
python train.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --config_path=$CONFIG_FILE &> log &
|
python train.py --net=$1 --dataset=$2 --dataset_path=$PATH1 &> log &
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $# == 5 ]
|
if [ $# == 5 ]
|
||||||
then
|
then
|
||||||
python train.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --pre_trained=$PATH2 --config_path=$CONFIG_FILE \
|
python train.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
|
||||||
--output_path './output' &> log &
|
|
||||||
fi
|
fi
|
||||||
cd ..
|
cd ..
|
||||||
|
|
|
@ -65,22 +65,6 @@ export DEVICE_ID=$3
|
||||||
export RANK_ID=0
|
export RANK_ID=0
|
||||||
export RANK_SIZE=1
|
export RANK_SIZE=1
|
||||||
|
|
||||||
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
|
|
||||||
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
|
|
||||||
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
|
|
||||||
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
|
|
||||||
else
|
|
||||||
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -d "train" ];
|
if [ -d "train" ];
|
||||||
then
|
then
|
||||||
rm -rf ./train
|
rm -rf ./train
|
||||||
|
@ -88,20 +72,16 @@ fi
|
||||||
mkdir ./train
|
mkdir ./train
|
||||||
cp ./train.py ./train
|
cp ./train.py ./train
|
||||||
cp -r ./src ./train
|
cp -r ./src ./train
|
||||||
cp -r ./model_utils ./train
|
|
||||||
cp -r ./*.yaml ./train
|
|
||||||
cd ./train || exit
|
cd ./train || exit
|
||||||
echo "start training for device $DEVICE_ID"
|
echo "start training for device $DEVICE_ID"
|
||||||
env > env.log
|
env > env.log
|
||||||
if [ $# == 4 ]
|
if [ $# == 4 ]
|
||||||
then
|
then
|
||||||
python train.py --net_name=$1 --dataset=$2 --device_target="GPU" --data_path=$PATH1 \
|
python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 &> log &
|
||||||
--config_path=$CONFIG_FILE --output_path './output' &> log &
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $# == 5 ]
|
if [ $# == 5 ]
|
||||||
then
|
then
|
||||||
python train.py --net_name=$1 --dataset=$2 --device_target="GPU" --data_path=$PATH1 --pre_trained=$PATH2 \
|
python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
|
||||||
--config_path=$CONFIG_FILE --output_path './output' &> log &
|
|
||||||
fi
|
fi
|
||||||
cd ..
|
cd ..
|
||||||
|
|
|
@ -1,60 +0,0 @@
|
||||||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
|
||||||
enable_modelarts: False
|
|
||||||
# Url for modelarts
|
|
||||||
data_url: ""
|
|
||||||
train_url: ""
|
|
||||||
checkpoint_url: ""
|
|
||||||
# Path for local
|
|
||||||
run_distribute: False
|
|
||||||
enable_profiling: False
|
|
||||||
data_path: "/cache/data"
|
|
||||||
output_path: "/cache/train"
|
|
||||||
load_path: "/cache/checkpoint_path/"
|
|
||||||
device_num: 1
|
|
||||||
device_id: 0
|
|
||||||
device_target: 'Ascend'
|
|
||||||
checkpoint_path: './checkpoint/'
|
|
||||||
checkpoint_file_path: 'suqeezenet_cifar10-120_195.ckpt'
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Training options
|
|
||||||
net_name: ""
|
|
||||||
dataset : "cifar10"
|
|
||||||
class_num: 10
|
|
||||||
batch_size: 32
|
|
||||||
loss_scale: 1024
|
|
||||||
momentum: 0.9
|
|
||||||
weight_decay: 0.0001
|
|
||||||
epoch_size: 120
|
|
||||||
pretrain_epoch_size: 0
|
|
||||||
save_checkpoint: True
|
|
||||||
save_checkpoint_epochs: 1
|
|
||||||
keep_checkpoint_max: 10
|
|
||||||
warmup_epochs: 5
|
|
||||||
lr_decay_mode: "poly"
|
|
||||||
lr_init: 0
|
|
||||||
lr_end: 0
|
|
||||||
lr_max: 0.01
|
|
||||||
pre_trained: ""
|
|
||||||
|
|
||||||
# export
|
|
||||||
file_name: "squeezenet"
|
|
||||||
file_format: "AIR"
|
|
||||||
|
|
||||||
---
|
|
||||||
# Help description for each configuration
|
|
||||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
|
||||||
data_url: 'Dataset url for obs'
|
|
||||||
train_url: 'Training output url for obs'
|
|
||||||
checkpoint_url: 'The location of checkpoint for obs'
|
|
||||||
data_path: 'Dataset path for local'
|
|
||||||
output_path: 'Training output path for local'
|
|
||||||
load_path: 'The location of checkpoint for obs'
|
|
||||||
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
|
||||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
|
||||||
num_classes: 'Class for dataset'
|
|
||||||
batch_size: "Batch size for training and evaluation"
|
|
||||||
epoch_size: "Total training epochs."
|
|
||||||
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
|
||||||
checkpoint_path: "The location of the checkpoint file."
|
|
||||||
checkpoint_file_path: "The location of the checkpoint file."
|
|
|
@ -1,62 +0,0 @@
|
||||||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
|
||||||
enable_modelarts: False
|
|
||||||
# Url for modelarts
|
|
||||||
data_url: ""
|
|
||||||
train_url: ""
|
|
||||||
checkpoint_url: ""
|
|
||||||
# Path for local
|
|
||||||
run_distribute: False
|
|
||||||
enable_profiling: False
|
|
||||||
data_path: "/cache/data"
|
|
||||||
output_path: "/cache/train"
|
|
||||||
load_path: "/cache/checkpoint_path/"
|
|
||||||
device_num: 1
|
|
||||||
device_id: 0
|
|
||||||
device_target: 'Ascend'
|
|
||||||
checkpoint_path: './checkpoint/'
|
|
||||||
checkpoint_file_path: 'suqeezenet_imagenet-200_5004.ckpt'
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Training options
|
|
||||||
net_name: ""
|
|
||||||
dataset : "imagenet"
|
|
||||||
class_num: 1000
|
|
||||||
batch_size: 32
|
|
||||||
loss_scale: 1024
|
|
||||||
momentum: 0.9
|
|
||||||
weight_decay: 0.00007
|
|
||||||
epoch_size: 200
|
|
||||||
pretrain_epoch_size: 0
|
|
||||||
save_checkpoint: True
|
|
||||||
save_checkpoint_epochs: 1
|
|
||||||
keep_checkpoint_max: 10
|
|
||||||
warmup_epochs: 0
|
|
||||||
lr_decay_mode: "poly"
|
|
||||||
use_label_smooth: True
|
|
||||||
label_smooth_factor: 0.1
|
|
||||||
lr_init: 0
|
|
||||||
lr_end: 0
|
|
||||||
lr_max: 0.01
|
|
||||||
pre_trained: ""
|
|
||||||
|
|
||||||
# export
|
|
||||||
file_name: "squeezenet"
|
|
||||||
file_format: "AIR"
|
|
||||||
|
|
||||||
---
|
|
||||||
# Help description for each configuration
|
|
||||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
|
||||||
data_url: 'Dataset url for obs'
|
|
||||||
train_url: 'Training output url for obs'
|
|
||||||
checkpoint_url: 'The location of checkpoint for obs'
|
|
||||||
data_path: 'Dataset path for local'
|
|
||||||
output_path: 'Training output path for local'
|
|
||||||
load_path: 'The location of checkpoint for obs'
|
|
||||||
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
|
||||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
|
||||||
num_classes: 'Class for dataset'
|
|
||||||
batch_size: "Batch size for training and evaluation"
|
|
||||||
epoch_size: "Total training epochs."
|
|
||||||
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
|
||||||
checkpoint_path: "The location of the checkpoint file."
|
|
||||||
checkpoint_file_path: "The location of the checkpoint file."
|
|
|
@ -1,59 +0,0 @@
|
||||||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
|
||||||
enable_modelarts: False
|
|
||||||
# Url for modelarts
|
|
||||||
data_url: ""
|
|
||||||
train_url: ""
|
|
||||||
checkpoint_url: ""
|
|
||||||
# Path for local
|
|
||||||
run_distribute: False
|
|
||||||
enable_profiling: False
|
|
||||||
data_path: "/cache/data"
|
|
||||||
output_path: "/cache/train"
|
|
||||||
load_path: "/cache/checkpoint_path/"
|
|
||||||
device_num: 1
|
|
||||||
device_target: 'Ascend'
|
|
||||||
checkpoint_path: './checkpoint/'
|
|
||||||
checkpoint_file_path: 'suqeezenet_residual_cifar10-150_195.ckpt'
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Training options
|
|
||||||
net_name: ""
|
|
||||||
dataset : "cifar10"
|
|
||||||
class_num: 10
|
|
||||||
batch_size: 32
|
|
||||||
loss_scale: 1024
|
|
||||||
momentum: 0.9
|
|
||||||
weight_decay: 0.0001
|
|
||||||
epoch_size: 150
|
|
||||||
pretrain_epoch_size: 0
|
|
||||||
save_checkpoint: True
|
|
||||||
save_checkpoint_epochs: 1
|
|
||||||
keep_checkpoint_max: 10
|
|
||||||
warmup_epochs: 5
|
|
||||||
lr_decay_mode: "linear"
|
|
||||||
lr_init: 0
|
|
||||||
lr_end: 0
|
|
||||||
lr_max: 0.01
|
|
||||||
pre_trained: ""
|
|
||||||
|
|
||||||
#export
|
|
||||||
file_name: "squeezenet"
|
|
||||||
file_format: "AIR"
|
|
||||||
|
|
||||||
---
|
|
||||||
# Help description for each configuration
|
|
||||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
|
||||||
data_url: 'Dataset url for obs'
|
|
||||||
train_url: 'Training output url for obs'
|
|
||||||
checkpoint_url: 'The location of checkpoint for obs'
|
|
||||||
data_path: 'Dataset path for local'
|
|
||||||
output_path: 'Training output path for local'
|
|
||||||
load_path: 'The location of checkpoint for obs'
|
|
||||||
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
|
||||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
|
||||||
num_classes: 'Class for dataset'
|
|
||||||
batch_size: "Batch size for training and evaluation"
|
|
||||||
epoch_size: "Total training epochs."
|
|
||||||
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
|
||||||
checkpoint_path: "The location of the checkpoint file."
|
|
||||||
checkpoint_file_path: "The location of the checkpoint file."
|
|
|
@ -1,62 +0,0 @@
|
||||||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
|
||||||
enable_modelarts: False
|
|
||||||
# Url for modelarts
|
|
||||||
data_url: ""
|
|
||||||
train_url: ""
|
|
||||||
checkpoint_url: ""
|
|
||||||
# Path for local
|
|
||||||
run_distribute: False
|
|
||||||
enable_profiling: False
|
|
||||||
data_path: "/cache/data"
|
|
||||||
output_path: "/cache/train"
|
|
||||||
load_path: "/cache/checkpoint_path/"
|
|
||||||
device_num: 1
|
|
||||||
device_id: 0
|
|
||||||
device_target: 'Ascend'
|
|
||||||
checkpoint_path: './checkpoint/'
|
|
||||||
checkpoint_file_path: 'suqeezenet_residual_imagenet-300_5004.ckpt'
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Training options
|
|
||||||
net_name: ""
|
|
||||||
dataset : "imagenet"
|
|
||||||
class_num: 1000
|
|
||||||
batch_size: 32
|
|
||||||
loss_scale: 1024
|
|
||||||
momentum: 0.9
|
|
||||||
weight_decay: 0.00007
|
|
||||||
epoch_size: 300
|
|
||||||
pretrain_epoch_size: 0
|
|
||||||
save_checkpoint: True
|
|
||||||
save_checkpoint_epochs: 1
|
|
||||||
keep_checkpoint_max: 10
|
|
||||||
warmup_epochs: 0
|
|
||||||
lr_decay_mode: "cosine"
|
|
||||||
use_label_smooth: True
|
|
||||||
label_smooth_factor: 0.1
|
|
||||||
lr_init: 0
|
|
||||||
lr_end: 0
|
|
||||||
lr_max: 0.01
|
|
||||||
pre_trained: ""
|
|
||||||
|
|
||||||
#export
|
|
||||||
file_name: "squeezenet"
|
|
||||||
file_format: "AIR"
|
|
||||||
|
|
||||||
---
|
|
||||||
# Help description for each configuration
|
|
||||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
|
||||||
data_url: 'Dataset url for obs'
|
|
||||||
train_url: 'Training output url for obs'
|
|
||||||
checkpoint_url: 'The location of checkpoint for obs'
|
|
||||||
data_path: 'Dataset path for local'
|
|
||||||
output_path: 'Training output path for local'
|
|
||||||
load_path: 'The location of checkpoint for obs'
|
|
||||||
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
|
|
||||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
|
||||||
num_classes: 'Class for dataset'
|
|
||||||
batch_size: "Batch size for training and evaluation"
|
|
||||||
epoch_size: "Total training epochs."
|
|
||||||
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
|
|
||||||
checkpoint_path: "The location of the checkpoint file."
|
|
||||||
checkpoint_file_path: "The location of the checkpoint file."
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
network config setting, will be used in train.py and eval.py
|
||||||
|
"""
|
||||||
|
from easydict import EasyDict as ed
|
||||||
|
|
||||||
|
# config for squeezenet, cifar10
|
||||||
|
config1 = ed({
|
||||||
|
"class_num": 10,
|
||||||
|
"batch_size": 32,
|
||||||
|
"loss_scale": 1024,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"weight_decay": 1e-4,
|
||||||
|
"epoch_size": 120,
|
||||||
|
"pretrain_epoch_size": 0,
|
||||||
|
"save_checkpoint": True,
|
||||||
|
"save_checkpoint_epochs": 1,
|
||||||
|
"keep_checkpoint_max": 10,
|
||||||
|
"save_checkpoint_path": "./",
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr_decay_mode": "poly",
|
||||||
|
"lr_init": 0,
|
||||||
|
"lr_end": 0,
|
||||||
|
"lr_max": 0.01
|
||||||
|
})
|
||||||
|
|
||||||
|
# config for squeezenet, imagenet
|
||||||
|
config2 = ed({
|
||||||
|
"class_num": 1000,
|
||||||
|
"batch_size": 32,
|
||||||
|
"loss_scale": 1024,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"weight_decay": 7e-5,
|
||||||
|
"epoch_size": 200,
|
||||||
|
"pretrain_epoch_size": 0,
|
||||||
|
"save_checkpoint": True,
|
||||||
|
"save_checkpoint_epochs": 1,
|
||||||
|
"keep_checkpoint_max": 10,
|
||||||
|
"save_checkpoint_path": "./",
|
||||||
|
"warmup_epochs": 0,
|
||||||
|
"lr_decay_mode": "poly",
|
||||||
|
"use_label_smooth": True,
|
||||||
|
"label_smooth_factor": 0.1,
|
||||||
|
"lr_init": 0,
|
||||||
|
"lr_end": 0,
|
||||||
|
"lr_max": 0.01
|
||||||
|
})
|
||||||
|
|
||||||
|
# config for squeezenet_residual, cifar10
|
||||||
|
config3 = ed({
|
||||||
|
"class_num": 10,
|
||||||
|
"batch_size": 32,
|
||||||
|
"loss_scale": 1024,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"weight_decay": 1e-4,
|
||||||
|
"epoch_size": 150,
|
||||||
|
"pretrain_epoch_size": 0,
|
||||||
|
"save_checkpoint": True,
|
||||||
|
"save_checkpoint_epochs": 1,
|
||||||
|
"keep_checkpoint_max": 10,
|
||||||
|
"save_checkpoint_path": "./",
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr_decay_mode": "linear",
|
||||||
|
"lr_init": 0,
|
||||||
|
"lr_end": 0,
|
||||||
|
"lr_max": 0.01
|
||||||
|
})
|
||||||
|
|
||||||
|
# config for squeezenet_residual, imagenet
|
||||||
|
config4 = ed({
|
||||||
|
"class_num": 1000,
|
||||||
|
"batch_size": 32,
|
||||||
|
"loss_scale": 1024,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"weight_decay": 7e-5,
|
||||||
|
"epoch_size": 300,
|
||||||
|
"pretrain_epoch_size": 0,
|
||||||
|
"save_checkpoint": True,
|
||||||
|
"save_checkpoint_epochs": 1,
|
||||||
|
"keep_checkpoint_max": 10,
|
||||||
|
"save_checkpoint_path": "./",
|
||||||
|
"warmup_epochs": 0,
|
||||||
|
"lr_decay_mode": "cosine",
|
||||||
|
"use_label_smooth": True,
|
||||||
|
"label_smooth_factor": 0.1,
|
||||||
|
"lr_init": 0,
|
||||||
|
"lr_end": 0,
|
||||||
|
"lr_max": 0.01
|
||||||
|
})
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""train squeezenet."""
|
"""train squeezenet."""
|
||||||
import os
|
import os
|
||||||
|
import argparse
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.nn.optim.momentum import Momentum
|
from mindspore.nn.optim.momentum import Momentum
|
||||||
|
@ -23,45 +24,55 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
|
||||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.communication.management import init, get_rank
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
from mindspore.common import set_seed
|
from mindspore.common import set_seed
|
||||||
from model_utils.config import config
|
|
||||||
from model_utils.device_adapter import get_device_num
|
|
||||||
from model_utils.moxing_adapter import moxing_wrapper
|
|
||||||
from src.lr_generator import get_lr
|
from src.lr_generator import get_lr
|
||||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
|
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
|
||||||
|
help='Model.')
|
||||||
|
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
|
||||||
|
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||||
|
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||||
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||||
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||||
|
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
if config.net_name == "squeezenet":
|
if args_opt.net == "squeezenet":
|
||||||
from src.squeezenet import SqueezeNet as squeezenet
|
from src.squeezenet import SqueezeNet as squeezenet
|
||||||
if config.dataset == "cifar10":
|
if args_opt.dataset == "cifar10":
|
||||||
|
from src.config import config1 as config
|
||||||
from src.dataset import create_dataset_cifar as create_dataset
|
from src.dataset import create_dataset_cifar as create_dataset
|
||||||
else:
|
else:
|
||||||
|
from src.config import config2 as config
|
||||||
from src.dataset import create_dataset_imagenet as create_dataset
|
from src.dataset import create_dataset_imagenet as create_dataset
|
||||||
else:
|
else:
|
||||||
from src.squeezenet import SqueezeNet_Residual as squeezenet
|
from src.squeezenet import SqueezeNet_Residual as squeezenet
|
||||||
if config.dataset == "cifar10":
|
if args_opt.dataset == "cifar10":
|
||||||
|
from src.config import config3 as config
|
||||||
from src.dataset import create_dataset_cifar as create_dataset
|
from src.dataset import create_dataset_cifar as create_dataset
|
||||||
else:
|
else:
|
||||||
|
from src.config import config4 as config
|
||||||
from src.dataset import create_dataset_imagenet as create_dataset
|
from src.dataset import create_dataset_imagenet as create_dataset
|
||||||
|
|
||||||
@moxing_wrapper()
|
if __name__ == '__main__':
|
||||||
def train_net():
|
target = args_opt.device_target
|
||||||
"""train net"""
|
ckpt_save_dir = config.save_checkpoint_path
|
||||||
target = config.device_target
|
|
||||||
ckpt_save_dir = config.output_path
|
|
||||||
|
|
||||||
# init context
|
# init context
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
device_target=target)
|
device_target=target)
|
||||||
if config.run_distribute:
|
if args_opt.run_distribute:
|
||||||
if target == "Ascend":
|
if target == "Ascend":
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
context.set_context(device_id=device_id,
|
context.set_context(device_id=device_id,
|
||||||
enable_auto_mixed_precision=True)
|
enable_auto_mixed_precision=True)
|
||||||
context.set_auto_parallel_context(
|
context.set_auto_parallel_context(
|
||||||
device_num=config.device_num,
|
device_num=args_opt.device_num,
|
||||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
gradients_mean=True)
|
gradients_mean=True)
|
||||||
init()
|
init()
|
||||||
|
@ -69,13 +80,14 @@ def train_net():
|
||||||
else:
|
else:
|
||||||
init()
|
init()
|
||||||
context.set_auto_parallel_context(
|
context.set_auto_parallel_context(
|
||||||
device_num=get_device_num(),
|
device_num=get_group_size(),
|
||||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
gradients_mean=True)
|
gradients_mean=True)
|
||||||
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank()) + "/"
|
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(
|
||||||
|
get_rank()) + "/"
|
||||||
|
|
||||||
# create dataset
|
# create dataset
|
||||||
dataset = create_dataset(dataset_path=config.data_path,
|
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||||
do_train=True,
|
do_train=True,
|
||||||
repeat_num=1,
|
repeat_num=1,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
|
@ -86,8 +98,8 @@ def train_net():
|
||||||
net = squeezenet(num_classes=config.class_num)
|
net = squeezenet(num_classes=config.class_num)
|
||||||
|
|
||||||
# load checkpoint
|
# load checkpoint
|
||||||
if config.pre_trained:
|
if args_opt.pre_trained:
|
||||||
param_dict = load_checkpoint(config.pre_trained)
|
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
|
|
||||||
# init lr
|
# init lr
|
||||||
|
@ -102,7 +114,7 @@ def train_net():
|
||||||
lr = Tensor(lr)
|
lr = Tensor(lr)
|
||||||
|
|
||||||
# define loss
|
# define loss
|
||||||
if config.dataset == "imagenet":
|
if args_opt.dataset == "imagenet":
|
||||||
if not config.use_label_smooth:
|
if not config.use_label_smooth:
|
||||||
config.label_smooth_factor = 0.0
|
config.label_smooth_factor = 0.0
|
||||||
loss = CrossEntropySmooth(sparse=True,
|
loss = CrossEntropySmooth(sparse=True,
|
||||||
|
@ -146,7 +158,7 @@ def train_net():
|
||||||
config_ck = CheckpointConfig(
|
config_ck = CheckpointConfig(
|
||||||
save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
||||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||||
ckpt_cb = ModelCheckpoint(prefix=config.net_name + '_' + config.dataset,
|
ckpt_cb = ModelCheckpoint(prefix=args_opt.net + '_' + args_opt.dataset,
|
||||||
directory=ckpt_save_dir,
|
directory=ckpt_save_dir,
|
||||||
config=config_ck)
|
config=config_ck)
|
||||||
cb += [ckpt_cb]
|
cb += [ckpt_cb]
|
||||||
|
@ -155,6 +167,3 @@ def train_net():
|
||||||
model.train(config.epoch_size - config.pretrain_epoch_size,
|
model.train(config.epoch_size - config.pretrain_epoch_size,
|
||||||
dataset,
|
dataset,
|
||||||
callbacks=cb)
|
callbacks=cb)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
train_net()
|
|
||||||
|
|
Loading…
Reference in New Issue