!17951 modify model_zoo squeezenet

From: @Somnus2020
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-06-08 15:04:44 +08:00 committed by Gitee
commit 77e562db4f
19 changed files with 200 additions and 775 deletions

View File

@ -100,37 +100,6 @@ After installing MindSpore via the official website, you can start training and
sh scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH] sh scripts/run_eval_gpu.sh [squeezenet|squeezenet_residual] [cifar10|imagenet] [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]
``` ```
To run on ModelArts, please refer to the official [ModelArts](https://support.huaweicloud.com/modelarts/) documentation; you can then start training and evaluation as follows:
```python
# run distributed training on modelarts example
# (1) First, perform either a or b.
# a. Set "enable_modelarts=True" in the yaml file.
# Set any other parameters you need in the yaml file.
# b. Add "enable_modelarts=True" on the website UI.
# Add any other parameters you need on the website UI.
# (2) Set the Dataset directory in config file.
# (3) Set the code directory to "/path/squeezenet" on the website UI interface.
# (4) Set the startup file to "train.py" on the website UI interface.
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (6) Create your job.
# run evaluation on modelarts example
# (1) Copy or upload your trained model to S3 bucket.
# (2) Perform a or b.
# a. Set "enable_modelarts=True" on yaml file.
# Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on yaml file.
# Set "checkpoint_url=/The path of checkpoint in S3/" on yaml file.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
# (3) Set the Dataset directory in config file.
# (4) Set the code directory to "/path/squeezenet" on the website UI interface.
# (5) Set the startup file to "eval.py" on the website UI interface.
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (7) Create your job.
```
# [Script Description](#contents) # [Script Description](#contents)
## [Script and Sample Code](#contents) ## [Script and Sample Code](#contents)
@ -147,19 +116,11 @@ After installing MindSpore via the official website, you can start training and
├── run_eval.sh # launch ascend evaluation ├── run_eval.sh # launch ascend evaluation
└── run_eval_gpu.sh # launch gpu evaluation └── run_eval_gpu.sh # launch gpu evaluation
├── src ├── src
├── config.py # parameter configuration
├── dataset.py # data preprocessing ├── dataset.py # data preprocessing
├── CrossEntropySmooth.py # loss definition for ImageNet dataset ├── CrossEntropySmooth.py # loss definition for ImageNet dataset
├── lr_generator.py # generate learning rate for each step ├── lr_generator.py # generate learning rate for each step
└── squeezenet.py # squeezenet architecture, including squeezenet and squeezenet_residual └── squeezenet.py # squeezenet architecture, including squeezenet and squeezenet_residual
├── model_utils
│ ├── device_adapter.py # device adapter
│ ├── local_adapter.py # local adapter
│ ├── moxing_adapter.py # moxing adapter
│ ├── config.py # parameter analysis
├── squeezenet_cifar10_config.yaml # parameter configuration
├── squeezenet_imagenet_config.yaml # parameter configuration
├── squeezenet_residual_cifar10_config.yaml # parameter configuration
├── squeezenet_residual_imagenet_config.yaml # parameter configuration
├── train.py # train net ├── train.py # train net
├── eval.py # eval net ├── eval.py # eval net
└── export.py # export checkpoint files into geir/onnx └── export.py # export checkpoint files into geir/onnx

View File

@ -14,34 +14,44 @@
# ============================================================================ # ============================================================================
"""eval squeezenet.""" """eval squeezenet."""
import os import os
import argparse
from mindspore import context from mindspore import context
from mindspore.common import set_seed from mindspore.common import set_seed
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from src.CrossEntropySmooth import CrossEntropySmooth from src.CrossEntropySmooth import CrossEntropySmooth
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
help='Model.')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
set_seed(1) set_seed(1)
if config.net_name == "squeezenet": if args_opt.net == "squeezenet":
from src.squeezenet import SqueezeNet as squeezenet from src.squeezenet import SqueezeNet as squeezenet
if config.dataset == "cifar10": if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset_cifar as create_dataset from src.dataset import create_dataset_cifar as create_dataset
else: else:
from src.config import config2 as config
from src.dataset import create_dataset_imagenet as create_dataset from src.dataset import create_dataset_imagenet as create_dataset
else: else:
from src.squeezenet import SqueezeNet_Residual as squeezenet from src.squeezenet import SqueezeNet_Residual as squeezenet
if config.dataset == "cifar10": if args_opt.dataset == "cifar10":
from src.config import config3 as config
from src.dataset import create_dataset_cifar as create_dataset from src.dataset import create_dataset_cifar as create_dataset
else: else:
from src.config import config4 as config
from src.dataset import create_dataset_imagenet as create_dataset from src.dataset import create_dataset_imagenet as create_dataset
@moxing_wrapper() if __name__ == '__main__':
def eval_net(): target = args_opt.device_target
"""eval net """
target = config.device_target
# init context # init context
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
@ -50,21 +60,22 @@ def eval_net():
device_id=device_id) device_id=device_id)
# create dataset # create dataset
dataset = create_dataset(dataset_path=config.data_path, dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False, do_train=False,
batch_size=config.batch_size, batch_size=config.batch_size,
target=target) target=target)
step_size = dataset.get_dataset_size()
# define net # define net
net = squeezenet(num_classes=config.class_num) net = squeezenet(num_classes=config.class_num)
# load checkpoint # load checkpoint
param_dict = load_checkpoint(config.checkpoint_file_path) param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net.set_train(False) net.set_train(False)
# define loss # define loss
if config.dataset == "imagenet": if args_opt.dataset == "imagenet":
if not config.use_label_smooth: if not config.use_label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, loss = CrossEntropySmooth(sparse=True,
@ -81,7 +92,4 @@ def eval_net():
# eval model # eval model
res = model.eval(dataset) res = model.eval(dataset)
print("result:", res, "ckpt=", config.checkpoint_file_path) print("result:", res, "ckpt=", args_opt.checkpoint_path)
if __name__ == '__main__':
eval_net()

View File

@ -17,29 +17,36 @@
python export.py --net squeezenet --dataset cifar10 --checkpoint_path squeezenet_cifar10-120_1562.ckpt python export.py --net squeezenet --dataset cifar10 --checkpoint_path squeezenet_cifar10-120_1562.ckpt
""" """
import argparse
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from model_utils.config import config
if __name__ == '__main__': if __name__ == '__main__':
if config.net_name == "squeezenet": parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
help='Model.')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args()
if args_opt.net == "squeezenet":
from src.squeezenet import SqueezeNet as squeezenet from src.squeezenet import SqueezeNet as squeezenet
else: else:
from src.squeezenet import SqueezeNet_Residual as squeezenet from src.squeezenet import SqueezeNet_Residual as squeezenet
if config.dataset == "cifar10": if args_opt.dataset == "cifar10":
num_classes = 10 num_classes = 10
else: else:
num_classes = 1000 num_classes = 1000
onnx_filename = config.net_name + '_' + config.dataset onnx_filename = args_opt.net + '_' + args_opt.dataset
air_filename = config.net_name + '_' + config.dataset air_filename = args_opt.net + '_' + args_opt.dataset
net = squeezenet(num_classes=num_classes) net = squeezenet(num_classes=num_classes)
assert config.checkpoint_file_path is not None, "checkpoint_file_path is None." assert args_opt.checkpoint_path is not None, "checkpoint_path is None."
param_dict = load_checkpoint(config.checkpoint_file_path) param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([1, 3, 227, 227], np.float32)) input_arr = Tensor(np.zeros([1, 3, 227, 227], np.float32))

View File

@ -1,124 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pformat
import yaml
# Relative path of the default yaml configuration file.
# NOTE(review): no other visible code in this module reads this name - confirm it is still needed.
_config_path = "./squeezenet_cifar10_config.yaml"
class Config:
    """
    Attribute-style view over a configuration dictionary.

    Every key of ``cfg_dict`` becomes an attribute. Nested dictionaries are
    wrapped in ``Config`` recursively; lists and tuples are stored as lists
    whose dictionary elements are wrapped as well.
    """

    def __init__(self, cfg_dict):
        for key, value in cfg_dict.items():
            if isinstance(value, dict):
                converted = Config(value)
            elif isinstance(value, (list, tuple)):
                converted = [Config(elem) if isinstance(elem, dict) else elem for elem in value]
            else:
                converted = value
            setattr(self, key, converted)

    def __str__(self):
        # Pretty-print the attribute dictionary so nested configs stay readable.
        return pformat(vars(self))

    def __repr__(self):
        return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="squeezenet_cifar10_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    One ``--<key>`` option is registered for every entry of ``cfg`` whose value
    is neither a list nor a dict; the yaml value is used as the default.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description, mapping option name to help text.
        choices: Optional mapping from option name to its allowed values.
        cfg_path: Path to the default yaml config (only used in the fallback help text).

    Returns:
        The namespace produced by ``parse_args()`` on the process arguments.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item, default in cfg.items():
        if isinstance(default, (list, dict)):
            # Structured values cannot be expressed as a single CLI option.
            continue
        help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
        choice = choices[item] if item in choices else None
        if isinstance(default, bool):
            # bool("False") is True, so booleans are parsed as Python literals.
            arg_type = ast.literal_eval
        elif default is None:
            # type(None) is not callable with a string; without this branch a
            # yaml ``null`` option could never be overridden on the command line.
            arg_type = str
        else:
            arg_type = type(default)
        parser.add_argument("--" + item, type=arg_type, default=default, choices=choice,
                            help=help_description)
    return parser.parse_args()
def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.

    Returns:
        A ``(cfg, cfg_helper)`` tuple. ``cfg`` is the first yaml document;
        ``cfg_helper`` is the second document when present, otherwise ``{}``.

    Raises:
        ValueError: If the file contains no document, more than two documents,
            or cannot be parsed as yaml.
    """
    with open(yaml_path, 'r') as fin:
        try:
            # NOTE(review): FullLoader is meant for trusted input only; switch to
            # SafeLoader if config files may come from an untrusted source.
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
        except yaml.YAMLError as err:
            # Only yaml errors are translated, and the cause is kept for debugging.
            raise ValueError("Failed to parse yaml") from err
    if not cfgs:
        raise ValueError("Failed to parse yaml")
    if len(cfgs) == 1:
        cfg_helper = {}
        cfg = cfgs[0]
    elif len(cfgs) == 2:
        cfg, cfg_helper = cfgs
    else:
        # Raised outside the try block so this specific message is no longer
        # replaced by the generic parse failure.
        raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml")
    print(cfg_helper)
    return cfg, cfg_helper
def merge(args, cfg):
    """
    Overlay the parsed command line values onto the yaml configuration.

    Args:
        args: Command line arguments (an ``argparse`` namespace).
        cfg: Base configuration; updated in place.

    Returns:
        The same ``cfg`` object, with every attribute of ``args`` written into it.
    """
    for name, value in vars(args).items():
        cfg[name] = value
    return cfg
def get_config():
    """
    Get Config according to the yaml file and cli arguments.

    The yaml file is located through ``--config_path`` (defaulting to
    ``../squeezenet_cifar10_config.yaml`` relative to this module), its
    entries are exposed as command line options, and the parsed values are
    merged back into the yaml configuration.

    Returns:
        A ``Config`` built from the merged configuration.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str,
                        default=os.path.join(current_dir, "../squeezenet_cifar10_config.yaml"),
                        help="Config file path")
    # parse_known_args: the remaining options are not registered yet.
    path_args, _ = parser.parse_known_args()
    default, helper = parse_yaml(path_args.config_path)
    # cfg_path must be passed by keyword: the fourth positional parameter of
    # parse_cli_to_yaml is ``choices``, so a positional path was used as the
    # choices mapping.
    args = parse_cli_to_yaml(parser, default, helper, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)
# Module-level configuration object, built once at import time by get_config().
config = get_config()

View File

@ -1,27 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from model_utils.config import config
if config.enable_modelarts:
from model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -1,36 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('RANK_ID', '0'))
def get_job_id():
    """Return the fixed job identifier used for local (non-ModelArts) runs."""
    local_job_name = "Local Job"
    return local_job_name

View File

@ -1,115 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from model_utils.config import config
# Number of sync_data() calls made so far in this process; sync_data appends it
# to the lock file name and then increments it.
_global_sync_count = 0
def get_device_id():
    """Return the ``DEVICE_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('DEVICE_ID', '0'))
def get_device_num():
    """Return the ``RANK_SIZE`` environment variable as an int (1 when unset)."""
    return int(os.environ.get('RANK_SIZE', '1'))
def get_rank_id():
    """Return the ``RANK_ID`` environment variable as an int (0 when unset)."""
    return int(os.environ.get('RANK_ID', '0'))
def get_job_id():
    """
    Return the job identifier taken from the ``JOB_ID`` environment variable.

    Returns:
        The value of ``JOB_ID``, or ``"default"`` when the variable is unset
        or empty.
    """
    job_id = os.getenv('JOB_ID')
    # os.getenv returns None for an unset variable; the previous ``!= ""``
    # check let that None through instead of falling back to "default".
    return job_id if job_id else "default"
def sync_data(from_path, to_path):
"""
Copy data from ``from_path`` to ``to_path`` with ``mox.file.copy_parallel``.

Download data from remote obs to local directory if the first url is remote url and the second one is local path.
Upload data from local directory to remote obs in the opposite case.

The copy is performed only by a device whose id is a multiple of
``min(get_device_num(), 8)`` and only if the lock file for this call does
not exist yet. The function then polls once per second until that lock
file exists.

Args:
from_path: Source path or url handed to ``mox.file.copy_parallel``.
to_path: Destination path or url handed to ``mox.file.copy_parallel``.
"""
import moxing as mox
import time
global _global_sync_count
# Each call uses its own lock file, numbered by the per-process call counter.
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices at most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
# Create the lock file that signals the copy is finished.
# NOTE(review): IOError from mknod is ignored; if the file could not be created
# for a reason other than already existing, the wait loop below never ends.
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
# Poll once per second until the lock file exists.
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.

Returns a decorator. The decorated function is run between a download
phase and an upload phase, both of which are guarded by
``config.enable_modelarts`` and use ``sync_data``.

Args:
pre_process: Optional callable invoked with no arguments before the wrapped function runs.
post_process: Optional callable invoked with no arguments after the wrapped function runs.

NOTE(review): the original indentation is not recoverable here, so it cannot be
confirmed whether ``pre_process`` runs only when ModelArts is enabled - verify.
NOTE(review): the result of ``run_func`` is not returned by ``wrapped_func``.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
# Save graphs into a sub-directory of the output path named after the rank id.
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
# Overwrite the configured device settings with the values read from the environment.
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -74,22 +74,6 @@ export RANK_TABLE_FILE=$PATH1
export SERVER_ID=0 export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID)) rank_start=$((DEVICE_NUM * SERVER_ID))
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
else
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
exit 1
fi
for((i=0; i<${DEVICE_NUM}; i++)) for((i=0; i<${DEVICE_NUM}; i++))
do do
export DEVICE_ID=${i} export DEVICE_ID=${i}
@ -98,21 +82,17 @@ do
mkdir ./train_parallel$i mkdir ./train_parallel$i
cp ./train.py ./train_parallel$i cp ./train.py ./train_parallel$i
cp -r ./src ./train_parallel$i cp -r ./src ./train_parallel$i
cp -r ./model_utils ./train_parallel$i
cp -r ./*.yaml ./train_parallel$i
cd ./train_parallel$i || exit cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID" echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log env > env.log
if [ $# == 4 ] if [ $# == 4 ]
then then
python train.py --net_name=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --data_path=$PATH2 \ python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
--config_path=$CONFIG_FILE --output_path './output' &> log &
fi fi
if [ $# == 5 ] if [ $# == 5 ]
then then
python train.py --net_name=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --data_path=$PATH2 \ python train.py --net=$1 --dataset=$2 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
--pre_trained=$PATH3 --config_path=$CONFIG_FILE --output_path './output' &> log &
fi fi
cd .. cd ..

View File

@ -64,42 +64,22 @@ ulimit -u unlimited
export DEVICE_NUM=8 export DEVICE_NUM=8
export RANK_SIZE=8 export RANK_SIZE=8
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
else
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
exit 1
fi
rm -rf ./train_parallel rm -rf ./train_parallel
mkdir ./train_parallel mkdir ./train_parallel
cp ./train.py ./train_parallel cp ./train.py ./train_parallel
cp -r ./src ./train_parallel cp -r ./src ./train_parallel
cp -r ./model_utils ./train_parallel
cp -r ./*.yaml ./train_parallel
cd ./train_parallel || exit cd ./train_parallel || exit
if [ $# == 3 ] if [ $# == 3 ]
then then
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \ mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py --net_name=$1 --dataset=$2 --run_distribute=True \ python train.py --net=$1 --dataset=$2 --run_distribute=True \
--device_num=$DEVICE_NUM --device_target="GPU" --data_path=$PATH1 \ --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log &
--config_path=$CONFIG_FILE --output_path './output' &> log &
fi fi
if [ $# == 4 ] if [ $# == 4 ]
then then
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \ mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py --net_name=$1 --dataset=$2 --run_distribute=True \ python train.py --net=$1 --dataset=$2 --run_distribute=True \
--device_num=$DEVICE_NUM --device_target="GPU" --data_path=$PATH1 --pre_trained=$PATH2 \ --device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
--config_path=$CONFIG_FILE --output_path './output' &> log &
fi fi

View File

@ -62,22 +62,6 @@ export DEVICE_ID=$3
export RANK_SIZE=$DEVICE_NUM export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0 export RANK_ID=0
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
else
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
exit 1
fi
if [ -d "eval" ]; if [ -d "eval" ];
then then
rm -rf ./eval rm -rf ./eval
@ -85,11 +69,8 @@ fi
mkdir ./eval mkdir ./eval
cp ./eval.py ./eval cp ./eval.py ./eval
cp -r ./src ./eval cp -r ./src ./eval
cp -r ./model_utils ./eval
cp -r ./*.yaml ./eval
cd ./eval || exit cd ./eval || exit
env > env.log env > env.log
echo "start evaluation for device $DEVICE_ID" echo "start evaluation for device $DEVICE_ID"
python eval.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --checkpoint_file_path=$PATH2 \ python eval.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
--config_path=$CONFIG_FILE --output_path './output' &> log &
cd .. cd ..

View File

@ -62,22 +62,6 @@ export DEVICE_ID=$3
export RANK_SIZE=$DEVICE_NUM export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0 export RANK_ID=0
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
else
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
exit 1
fi
if [ -d "eval" ]; if [ -d "eval" ];
then then
rm -rf ./eval rm -rf ./eval
@ -85,11 +69,8 @@ fi
mkdir ./eval mkdir ./eval
cp ./eval.py ./eval cp ./eval.py ./eval
cp -r ./src ./eval cp -r ./src ./eval
cp -r ./model_utils ./eval
cp -r ./*.yaml ./eval
cd ./eval || exit cd ./eval || exit
env > env.log env > env.log
echo "start evaluation for device $DEVICE_ID" echo "start evaluation for device $DEVICE_ID"
python eval.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --checkpoint_file_path=$PATH2 --device_target="GPU" \ python eval.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
--config_path=$CONFIG_FILE --output_path './output' &> log &
cd .. cd ..

View File

@ -65,22 +65,6 @@ export DEVICE_ID=$3
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
else
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
exit 1
fi
if [ -d "train" ]; if [ -d "train" ];
then then
rm -rf ./train rm -rf ./train
@ -88,19 +72,16 @@ fi
mkdir ./train mkdir ./train
cp ./train.py ./train cp ./train.py ./train
cp -r ./src ./train cp -r ./src ./train
cp -r ./model_utils ./train
cp -r ./*.yaml ./train
cd ./train || exit cd ./train || exit
echo "start training for device $DEVICE_ID" echo "start training for device $DEVICE_ID"
env > env.log env > env.log
if [ $# == 4 ] if [ $# == 4 ]
then then
python train.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --config_path=$CONFIG_FILE &> log & python train.py --net=$1 --dataset=$2 --dataset_path=$PATH1 &> log &
fi fi
if [ $# == 5 ] if [ $# == 5 ]
then then
python train.py --net_name=$1 --dataset=$2 --data_path=$PATH1 --pre_trained=$PATH2 --config_path=$CONFIG_FILE \ python train.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
--output_path './output' &> log &
fi fi
cd .. cd ..

View File

@ -65,22 +65,6 @@ export DEVICE_ID=$3
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
BASE_PATH=$(dirname "$(dirname "$(readlink -f $0)")")
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
if [ $1 == "squeezenet" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_cifar10_config.yaml"
elif [ $1 == "squeezenet" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_imagenet_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "cifar10" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_cifar10_config.yaml"
elif [ $1 == "squeezenet_residual" ] && [ $2 == "imagenet" ]; then
CONFIG_FILE="${BASE_PATH}/squeezenet_residual_imagenet_config.yaml"
else
echo "error: the selected dataset is not in supported set{squeezenet, squeezenet_residual, cifar10, imagenet}"
exit 1
fi
if [ -d "train" ]; if [ -d "train" ];
then then
rm -rf ./train rm -rf ./train
@ -88,20 +72,16 @@ fi
mkdir ./train mkdir ./train
cp ./train.py ./train cp ./train.py ./train
cp -r ./src ./train cp -r ./src ./train
cp -r ./model_utils ./train
cp -r ./*.yaml ./train
cd ./train || exit cd ./train || exit
echo "start training for device $DEVICE_ID" echo "start training for device $DEVICE_ID"
env > env.log env > env.log
if [ $# == 4 ] if [ $# == 4 ]
then then
python train.py --net_name=$1 --dataset=$2 --device_target="GPU" --data_path=$PATH1 \ python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 &> log &
--config_path=$CONFIG_FILE --output_path './output' &> log &
fi fi
if [ $# == 5 ] if [ $# == 5 ]
then then
python train.py --net_name=$1 --dataset=$2 --device_target="GPU" --data_path=$PATH1 --pre_trained=$PATH2 \ python train.py --net=$1 --dataset=$2 --device_target="GPU" --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
--config_path=$CONFIG_FILE --output_path './output' &> log &
fi fi
cd .. cd ..

View File

@ -1,60 +0,0 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
run_distribute: False
enable_profiling: False
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_num: 1
device_id: 0
device_target: 'Ascend'
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'suqeezenet_cifar10-120_195.ckpt'
# ==============================================================================
# Training options
net_name: ""
dataset : "cifar10"
class_num: 10
batch_size: 32
loss_scale: 1024
momentum: 0.9
weight_decay: 0.0001
epoch_size: 120
pretrain_epoch_size: 0
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 10
warmup_epochs: 5
lr_decay_mode: "poly"
lr_init: 0
lr_end: 0
lr_max: 0.01
pre_trained: ""
# export
file_name: "squeezenet"
file_format: "AIR"
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for obs'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
num_classes: 'Class for dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -1,62 +0,0 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
run_distribute: False
enable_profiling: False
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_num: 1
device_id: 0
device_target: 'Ascend'
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'suqeezenet_imagenet-200_5004.ckpt'
# ==============================================================================
# Training options
net_name: ""
dataset : "imagenet"
class_num: 1000
batch_size: 32
loss_scale: 1024
momentum: 0.9
weight_decay: 0.00007
epoch_size: 200
pretrain_epoch_size: 0
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 10
warmup_epochs: 0
lr_decay_mode: "poly"
use_label_smooth: True
label_smooth_factor: 0.1
lr_init: 0
lr_end: 0
lr_max: 0.01
pre_trained: ""
# export
file_name: "squeezenet"
file_format: "AIR"
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for local'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
class_num: 'Number of classes in the dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -1,59 +0,0 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
run_distribute: False
enable_profiling: False
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_num: 1
device_target: 'Ascend'
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'suqeezenet_residual_cifar10-150_195.ckpt'
# ==============================================================================
# Training options
net_name: ""
dataset : "cifar10"
class_num: 10
batch_size: 32
loss_scale: 1024
momentum: 0.9
weight_decay: 0.0001
epoch_size: 150
pretrain_epoch_size: 0
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 10
warmup_epochs: 5
lr_decay_mode: "linear"
lr_init: 0
lr_end: 0
lr_max: 0.01
pre_trained: ""
#export
file_name: "squeezenet"
file_format: "AIR"
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for local'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
class_num: 'Number of classes in the dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -1,62 +0,0 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
run_distribute: False
enable_profiling: False
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_num: 1
device_id: 0
device_target: 'Ascend'
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'suqeezenet_residual_imagenet-300_5004.ckpt'
# ==============================================================================
# Training options
net_name: ""
dataset : "imagenet"
class_num: 1000
batch_size: 32
loss_scale: 1024
momentum: 0.9
weight_decay: 0.00007
epoch_size: 300
pretrain_epoch_size: 0
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 10
warmup_epochs: 0
lr_decay_mode: "cosine"
use_label_smooth: True
label_smooth_factor: 0.1
lr_init: 0
lr_end: 0
lr_max: 0.01
pre_trained: ""
#export
file_name: "squeezenet"
file_format: "AIR"
---
# Help description for each configuration
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
checkpoint_url: 'The location of checkpoint for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The location of checkpoint for local'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether enable profiling while training, default: False'
class_num: 'Number of classes in the dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint"
checkpoint_path: "The location of the checkpoint file."
checkpoint_file_path: "The location of the checkpoint file."

View File

@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
# Hyper-parameters for the plain SqueezeNet network on CIFAR-10.
config1 = ed(dict(
    # Model / data: class_num is handed to the network as num_classes.
    class_num=10,
    batch_size=32,
    # Optimisation settings (presumably consumed by the loss-scale manager
    # and the Momentum optimizer -- confirm against train.py).
    loss_scale=1024,
    momentum=0.9,
    weight_decay=1e-4,
    # Training length: train.py runs epoch_size - pretrain_epoch_size epochs.
    epoch_size=120,
    pretrain_epoch_size=0,
    # Checkpointing: saved every save_checkpoint_epochs epochs under
    # save_checkpoint_path, keeping at most keep_checkpoint_max files.
    save_checkpoint=True,
    save_checkpoint_epochs=1,
    keep_checkpoint_max=10,
    save_checkpoint_path="./",
    # Learning-rate schedule (presumably the inputs of get_lr -- confirm).
    warmup_epochs=5,
    lr_decay_mode="poly",
    lr_init=0,
    lr_end=0,
    lr_max=0.01,
))
# Hyper-parameters for the plain SqueezeNet network on ImageNet.
config2 = ed(dict(
    # Model / data: class_num is handed to the network as num_classes.
    class_num=1000,
    batch_size=32,
    # Optimisation settings (presumably consumed by the loss-scale manager
    # and the Momentum optimizer -- confirm against train.py).
    loss_scale=1024,
    momentum=0.9,
    weight_decay=7e-5,
    # Training length: train.py runs epoch_size - pretrain_epoch_size epochs.
    epoch_size=200,
    pretrain_epoch_size=0,
    # Checkpointing: saved every save_checkpoint_epochs epochs under
    # save_checkpoint_path, keeping at most keep_checkpoint_max files.
    save_checkpoint=True,
    save_checkpoint_epochs=1,
    keep_checkpoint_max=10,
    save_checkpoint_path="./",
    # Learning-rate schedule (presumably the inputs of get_lr -- confirm).
    warmup_epochs=0,
    lr_decay_mode="poly",
    # Label smoothing for the ImageNet loss: when use_label_smooth is False,
    # train.py resets label_smooth_factor to 0.0 before building the loss.
    use_label_smooth=True,
    label_smooth_factor=0.1,
    lr_init=0,
    lr_end=0,
    lr_max=0.01,
))
# Hyper-parameters for the residual SqueezeNet variant on CIFAR-10.
config3 = ed(dict(
    # Model / data: class_num is handed to the network as num_classes.
    class_num=10,
    batch_size=32,
    # Optimisation settings (presumably consumed by the loss-scale manager
    # and the Momentum optimizer -- confirm against train.py).
    loss_scale=1024,
    momentum=0.9,
    weight_decay=1e-4,
    # Training length: train.py runs epoch_size - pretrain_epoch_size epochs.
    epoch_size=150,
    pretrain_epoch_size=0,
    # Checkpointing: saved every save_checkpoint_epochs epochs under
    # save_checkpoint_path, keeping at most keep_checkpoint_max files.
    save_checkpoint=True,
    save_checkpoint_epochs=1,
    keep_checkpoint_max=10,
    save_checkpoint_path="./",
    # Learning-rate schedule (presumably the inputs of get_lr -- confirm).
    warmup_epochs=5,
    lr_decay_mode="linear",
    lr_init=0,
    lr_end=0,
    lr_max=0.01,
))
# Hyper-parameters for the residual SqueezeNet variant on ImageNet.
config4 = ed(dict(
    # Model / data: class_num is handed to the network as num_classes.
    class_num=1000,
    batch_size=32,
    # Optimisation settings (presumably consumed by the loss-scale manager
    # and the Momentum optimizer -- confirm against train.py).
    loss_scale=1024,
    momentum=0.9,
    weight_decay=7e-5,
    # Training length: train.py runs epoch_size - pretrain_epoch_size epochs.
    epoch_size=300,
    pretrain_epoch_size=0,
    # Checkpointing: saved every save_checkpoint_epochs epochs under
    # save_checkpoint_path, keeping at most keep_checkpoint_max files.
    save_checkpoint=True,
    save_checkpoint_epochs=1,
    keep_checkpoint_max=10,
    save_checkpoint_path="./",
    # Learning-rate schedule (presumably the inputs of get_lr -- confirm).
    warmup_epochs=0,
    lr_decay_mode="cosine",
    # Label smoothing for the ImageNet loss: when use_label_smooth is False,
    # train.py resets label_smooth_factor to 0.0 before building the loss.
    use_label_smooth=True,
    label_smooth_factor=0.1,
    lr_init=0,
    lr_end=0,
    lr_max=0.01,
))

View File

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""train squeezenet.""" """train squeezenet."""
import os import os
import argparse
from mindspore import context from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
@ -23,45 +24,55 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import set_seed from mindspore.common import set_seed
from model_utils.config import config
from model_utils.device_adapter import get_device_num
from model_utils.moxing_adapter import moxing_wrapper
from src.lr_generator import get_lr from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth from src.CrossEntropySmooth import CrossEntropySmooth
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because any non-empty string is truthy, so ``--run_distribute False``
    would silently enable distributed training.

    Args:
        value: The raw option value (a string, or already a bool).

    Returns:
        True for true/t/yes/y/1, False for false/f/no/n/0 or an empty string
        (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r.' % (value,))


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default='squeezenet', choices=['squeezenet', 'squeezenet_residual'],
                    help='Model.')
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'], help='Dataset.')
# Parsed with _str2bool so that "False"/"0" really disable distributed mode.
parser.add_argument('--run_distribute', type=_str2bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
args_opt = parser.parse_args()
set_seed(1) set_seed(1)
if config.net_name == "squeezenet": if args_opt.net == "squeezenet":
from src.squeezenet import SqueezeNet as squeezenet from src.squeezenet import SqueezeNet as squeezenet
if config.dataset == "cifar10": if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset_cifar as create_dataset from src.dataset import create_dataset_cifar as create_dataset
else: else:
from src.config import config2 as config
from src.dataset import create_dataset_imagenet as create_dataset from src.dataset import create_dataset_imagenet as create_dataset
else: else:
from src.squeezenet import SqueezeNet_Residual as squeezenet from src.squeezenet import SqueezeNet_Residual as squeezenet
if config.dataset == "cifar10": if args_opt.dataset == "cifar10":
from src.config import config3 as config
from src.dataset import create_dataset_cifar as create_dataset from src.dataset import create_dataset_cifar as create_dataset
else: else:
from src.config import config4 as config
from src.dataset import create_dataset_imagenet as create_dataset from src.dataset import create_dataset_imagenet as create_dataset
@moxing_wrapper() if __name__ == '__main__':
def train_net(): target = args_opt.device_target
"""train net""" ckpt_save_dir = config.save_checkpoint_path
target = config.device_target
ckpt_save_dir = config.output_path
# init context # init context
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target=target) device_target=target)
if config.run_distribute: if args_opt.run_distribute:
if target == "Ascend": if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, context.set_context(device_id=device_id,
enable_auto_mixed_precision=True) enable_auto_mixed_precision=True)
context.set_auto_parallel_context( context.set_auto_parallel_context(
device_num=config.device_num, device_num=args_opt.device_num,
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True) gradients_mean=True)
init() init()
@ -69,13 +80,14 @@ def train_net():
else: else:
init() init()
context.set_auto_parallel_context( context.set_auto_parallel_context(
device_num=get_device_num(), device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True) gradients_mean=True)
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank()) + "/" ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(
get_rank()) + "/"
# create dataset # create dataset
dataset = create_dataset(dataset_path=config.data_path, dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True, do_train=True,
repeat_num=1, repeat_num=1,
batch_size=config.batch_size, batch_size=config.batch_size,
@ -86,8 +98,8 @@ def train_net():
net = squeezenet(num_classes=config.class_num) net = squeezenet(num_classes=config.class_num)
# load checkpoint # load checkpoint
if config.pre_trained: if args_opt.pre_trained:
param_dict = load_checkpoint(config.pre_trained) param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
# init lr # init lr
@ -102,7 +114,7 @@ def train_net():
lr = Tensor(lr) lr = Tensor(lr)
# define loss # define loss
if config.dataset == "imagenet": if args_opt.dataset == "imagenet":
if not config.use_label_smooth: if not config.use_label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, loss = CrossEntropySmooth(sparse=True,
@ -146,7 +158,7 @@ def train_net():
config_ck = CheckpointConfig( config_ck = CheckpointConfig(
save_checkpoint_steps=config.save_checkpoint_epochs * step_size, save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max) keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=config.net_name + '_' + config.dataset, ckpt_cb = ModelCheckpoint(prefix=args_opt.net + '_' + args_opt.dataset,
directory=ckpt_save_dir, directory=ckpt_save_dir,
config=config_ck) config=config_ck)
cb += [ckpt_cb] cb += [ckpt_cb]
@ -155,6 +167,3 @@ def train_net():
model.train(config.epoch_size - config.pretrain_epoch_size, model.train(config.epoch_size - config.pretrain_epoch_size,
dataset, dataset,
callbacks=cb) callbacks=cb)
if __name__ == '__main__':
train_net()