align to docs and r1.0

This commit is contained in:
Payne 2020-09-17 13:08:32 +08:00
parent 93c4d2929c
commit c53cd6bb22
8 changed files with 96 additions and 51 deletions

View File

@ -13,13 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
run_ascend()
{
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
}
if [ $# -gt 4 ] || [ $# -lt 3 ]
then
echo "Ascend: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH] \
GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Usage:
Ascend: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
GPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]"
exit 1
fi
# check dataset path
if [ ! -d $2 ]
then
@ -30,16 +43,13 @@ fi
# check checkpoint file
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$3 is not a file"
echo "error: PRETRAIN_CKPT=$3 is not a file"
exit 1
fi
# set environment
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "../eval" ];
then
rm -rf ../eval
@ -47,9 +57,14 @@ fi
mkdir ../eval
cd ../eval || exit
if [ $1 = "CPU" ] ; then
run_ascend "$@"
fi;
# launch
python ${BASEPATH}/../eval.py \
--device_target=$1 \
--dataset_path=$2 \
--checkpoint_path=$3 \
&> ../infer.log & # dataset val folder path
--platform=$1 \
--dataset_path=$2 \
--pretrain_ckpt=$3 \
--head_ckpt=$4 \
&> ../eval.log & # dataset val folder path

View File

@ -38,12 +38,14 @@ run_ascend()
mkdir ../train
cd ../train || exit
python ${BASEPATH}/../src/launch.py \
--platform=$1 \
--nproc_per_node=$2 \
--visible_devices=$3 \
--training_script=${BASEPATH}/../train.py \
--dataset_path=$5 \
--pre_trained=$6 \
--device_target=$1 &> ../train.log & # dataset train folder
--train_method=$6 \
--pretrain_ckpt=$7 \
&> ../train.log & # dataset train folder
}
run_gpu()
@ -72,17 +74,43 @@ run_gpu()
export CUDA_VISIBLE_DEVICES="$3"
mpirun -n $2 --allow-run-as-root \
python ${BASEPATH}/../train.py \
--platform=$1 \
--dataset_path=$4 \
--pre_trained=$5 \
--device_target=$1 \
--train_method=$5 \
--pretrain_ckpt=$6 \
&> ../train.log & # dataset train folder
}
if [ $# -gt 6 ] || [ $# -lt 4 ]
run_cpu()
{
if [ ! -d $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
python ${BASEPATH}/../train.py \
--platform=$1 \
--dataset_path=$2 \
--train_method=$3 \
--pretrain_ckpt=$4 \
&> ../train.log & # dataset train folder
}
if [ $# -gt 7 ] || [ $# -lt 4 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \
CPU: sh run_train.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \
"
exit 1
fi
@ -91,7 +119,8 @@ if [ $1 = "Ascend" ] ; then
run_ascend "$@"
elif [ $1 = "GPU" ] ; then
run_gpu "$@"
elif [ $1 = "CPU" ] ; then
run_cpu "$@"
else
echo "Unsupported device_target."
echo "Unsupported platform."
fi;

View File

@ -38,25 +38,24 @@ def launch_parse_args():
def train_parse_args():
train_parser = argparse.ArgumentParser(description='Image classification trian')
train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \
help='run platform, only support CPU, GPU and Ascend')
train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \
help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \
train from initialization model")
train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \
for fine tune or incremental learning')
train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
train_parser.add_argument('--train_method', type=str, required=True, choices=("train", "fine_tune", \
"incremental_learn"), help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after \
loading the ckpt, \"train\" to train from initialization model")
train_args = train_parser.parse_args()
return train_args
def eval_parse_args():
eval_parser = argparse.ArgumentParser(description='Image classification eval')
eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
eval_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend')
eval_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \
eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
eval_parser.add_argument('--pretrain_ckpt', type=str, required=True, help='Pretrained checkpoint path \
for fine tune or incremental learning')
eval_parser.add_argument('--head_ckpt', type=str, default=None, help='Pretrained checkpoint path \
for fine tune or incremental learning')

View File

@ -37,7 +37,8 @@ def set_config(args):
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 20,
"save_checkpoint_path": "./checkpoint",
"platform": args.platform
"platform": args.platform,
"run_distribute": False
})
config_gpu = ed({
"num_classes": 1000,

View File

@ -38,17 +38,17 @@ def main():
for rank_id in range(0, args.nproc_per_node):
os.chdir(cur_path)
device_id = visible_devices[rank_id]
device_dir = os.path.join(cur_path, 'device{}'.format(rank_id))
rank_dir = os.path.join(cur_path, f'rank{rank_id}')
env['RANK_ID'] = str(rank_id)
env['DEVICE_ID'] = str(device_id)
if os.path.exists(device_dir):
shutil.rmtree(device_dir)
os.mkdir(device_dir)
os.chdir(device_dir)
if os.path.exists(rank_dir):
shutil.rmtree(rank_dir)
os.mkdir(rank_dir)
os.chdir(rank_dir)
cmd = [sys.executable, '-u']
cmd.append(args.training_script)
cmd.extend(args.training_script_args)
log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
log_file = open(f'{rank_dir}/log{rank_id}.log', 'w')
process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
processes.append(process)
cmds.append(cmd)

View File

@ -119,20 +119,9 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True):
for param in network.get_parameters():
param.requires_grad = False
def define_net(args, config):
backbone_net = MobileNetV2Backbone(platform=args.platform)
def define_net(config):
backbone_net = MobileNetV2Backbone(platform=config.platform)
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(backbone_net, head_net)
# load the ckpt file to the network for fine tune or incremental leaning
if args.pretrain_ckpt:
if args.train_method == "fine_tune":
load_ckpt(net, args.pretrain_ckpt)
elif args.train_method == "incremental_learn":
load_ckpt(backbone_net, args.pretrain_ckpt, trainable=False)
elif args.train_method == "train":
pass
else:
raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")
return backbone_net, head_net, net

View File

@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
from mindspore.train.model import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import get_rank, init, get_group_size
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from src.models import Monitor
@ -58,8 +59,8 @@ def context_device_init(config):
if config.run_distribute:
context.set_auto_parallel_context(device_num=config.rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, gradients_mean=True,
all_reduce_fusion_config=[140])
parameter_broadcast=True, mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
else:
raise ValueError("Only support CPU, GPU and Ascend.")

View File

@ -35,7 +35,7 @@ from src.config import set_config
from src.args import train_parse_args
from src.utils import context_device_init, switch_precision, config_ckpoint, set_seed
from src.models import CrossEntropyWithLabelSmooth, define_net
from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt
set_seed(1)
@ -50,7 +50,18 @@ if __name__ == '__main__':
context_device_init(config)
# define network
backbone_net, head_net, net = define_net(args_opt, config)
backbone_net, head_net, net = define_net(config)
# load the ckpt file to the network for fine tune or incremental leaning
if args_opt.pretrain_ckpt:
if args_opt.train_method == "fine_tune":
load_ckpt(net, args_opt.pretrain_ckpt)
elif args_opt.train_method == "incremental_learn":
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
elif args_opt.train_method == "train":
pass
else:
raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")
# CPU only support "incremental_learn"
if args_opt.train_method == "incremental_learn":