!9811 fix mobilentv2 CPU not support full training

From: @zhao_ting_v
Reviewed-by: @c_34,@wuxuejian
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2020-12-11 16:27:52 +08:00 committed by Gitee
commit c329ed4d27
8 changed files with 40 additions and 38 deletions

View File

@ -91,6 +91,12 @@ You can start training using python or shell scripts. The usage of shell scripts
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
- CPU: sh run_trian.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - CPU: sh run_trian.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
> RANK_TABLE_FILE is HCCL configuration file when running on Ascend.
> The common restrictions on using the distributed service are as follows. For details, see the HCCL documentation.
>
> - In a single-node system, a cluster of 1, 2, 4, or 8 devices is supported. In a multi-node system, a cluster of 8 x N devices is supported.
> - Each host has four devices numbered 0 to 3 and four devices numbered 4 to 7 deployed on two different networks. During training of 2 or 4 devices, the devices must be connected and clusters cannot be created across networks.
### Launch ### Launch
```shell ```shell

View File

@ -100,6 +100,12 @@ MobileNetV2总体网络架构如下
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
- CPU: sh run_trian.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] - CPU: sh run_trian.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
> RANK_TABLE_FILE 是在Ascned上运行分布式任务时HCCL的配置文件
> 我们列出使用分布式服务常见的使用限制详细的可以查看HCCL对应的使用文档。
>
> - 单机场景下支持1、2、4、8卡设备集群多机场景下支持8*n卡设备集群。
> - 每台机器的0-3卡和4-7卡各为1个组网2卡和4卡训练时卡必须相连且不支持跨组网创建集群。
### 启动 ### 启动
```shell ```shell

View File

@ -29,6 +29,13 @@ run_ascend()
fi fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd) BASEPATH=$(cd "`dirname $0`" || exit; pwd)
VISIABLE_DEVICES=$3
IFS="," read -r -a CANDIDATE_DEVICE <<< "$VISIABLE_DEVICES"
if [ ${#CANDIDATE_DEVICE[@]} -ne $2 ]
then
echo "error: DEVICE_NUM=$2 is not equal to the length of VISIABLE_DEVICES=$3"
exit 1
fi
export PYTHONPATH=${BASEPATH}:$PYTHONPATH export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export RANK_TABLE_FILE=$4 export RANK_TABLE_FILE=$4
export RANK_SIZE=$2 export RANK_SIZE=$2
@ -40,7 +47,7 @@ run_ascend()
cd ../train || exit cd ../train || exit
for((i=0; i<${RANK_SIZE}; i++)) for((i=0; i<${RANK_SIZE}; i++))
do do
export DEVICE_ID=$i export DEVICE_ID=${CANDIDATE_DEVICE[i]}
export RANK_ID=$i export RANK_ID=$i
rm -rf ./rank$i rm -rf ./rank$i
mkdir ./rank$i mkdir ./rank$i

View File

@ -16,26 +16,6 @@
import argparse import argparse
import ast import ast
def launch_parse_args():
launch_parser = argparse.ArgumentParser(description="mindspore distributed training launch helper utilty \
that will spawn up multiple distributed processes")
launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend')
launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(1, 2, 3, 4, 5, 6, 7, 8), \
help="The number of processes to launch on each node, for D training, this is recommended to be set \
to the number of D in your system so that each process can be bound to a single D.")
launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \
visible devices sequentially")
launch_parser.add_argument("--training_script", type=str, default="./train.py", help="The full path to \
the single D training program/script to be launched in parallel, followed by all the arguments for \
the training script")
launch_args, unknown = launch_parser.parse_known_args()
launch_args.training_script_args = unknown
launch_args.training_script_args += ["--platform", launch_args.platform]
return launch_args
def train_parse_args(): def train_parse_args():
train_parser = argparse.ArgumentParser(description='Image classification trian') train_parser = argparse.ArgumentParser(description='Image classification trian')
train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \ train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \
@ -48,6 +28,8 @@ def train_parse_args():
train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
train_args = train_parser.parse_args() train_args = train_parser.parse_args()
train_args.is_training = True train_args.is_training = True
if train_args.platform == "CPU":
train_args.run_distribute = False
return train_args return train_args
def eval_parse_args(): def eval_parse_args():

View File

@ -40,6 +40,7 @@ def set_config(args):
"keep_checkpoint_max": 20, "keep_checkpoint_max": 20,
"save_checkpoint_path": "./", "save_checkpoint_path": "./",
"platform": args.platform, "platform": args.platform,
"run_distribute": args.run_distribute,
"activation": "Softmax", "activation": "Softmax",
"export_format": "MINDIR", "export_format": "MINDIR",
"export_file": "mobilenetv2" "export_file": "mobilenetv2"

View File

@ -331,7 +331,7 @@ class MobileNetV2Combine(nn.Cell):
Tensor, output tensor. Tensor, output tensor.
Examples: Examples:
>>> MobileNetV2(num_classes=1000) >>> MobileNetV2Combine(backbone, head)
""" """
def __init__(self, backbone, head): def __init__(self, backbone, head):

View File

@ -114,6 +114,13 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True):
incremental_learning or not incremental_learning or not
""" """
param_dict = load_checkpoint(pretrain_ckpt_path) param_dict = load_checkpoint(pretrain_ckpt_path)
if hasattr(network, "head"):
head_param = network.head.parameters_dict()
for k, v in head_param.items():
if param_dict[k].shape != v.shape:
param_dict.pop(k)
param_dict.pop(f"moments.{k}")
print(f"Filter {k} don't load weights from checkpoint.")
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
if not trainable: if not trainable:
for param in network.get_parameters(): for param in network.get_parameters():

View File

@ -53,21 +53,14 @@ if __name__ == '__main__':
# define network # define network
backbone_net, head_net, net = define_net(config, args_opt.is_training) backbone_net, head_net, net = define_net(config, args_opt.is_training)
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
if args_opt.pretrain_ckpt != "" and args_opt.freeze_layer == "backbone": step_size = dataset.get_dataset_size()
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False) if args_opt.pretrain_ckpt:
step_size = extract_features(backbone_net, args_opt.dataset_path, config) if args_opt.freeze_layer == "backbone":
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
else: step_size = extract_features(backbone_net, args_opt.dataset_path, config)
if args_opt.platform == "CPU": else:
raise ValueError("CPU only support fine tune the head net, doesn't support fine tune the all net") load_ckpt(net, args_opt.pretrain_ckpt)
if args_opt.pretrain_ckpt:
load_ckpt(backbone_net, args_opt.pretrain_ckpt)
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
step_size = dataset.get_dataset_size()
if step_size == 0: if step_size == 0:
raise ValueError("The step_size of dataset is zero. Check if the images' count of train dataset is more \ raise ValueError("The step_size of dataset is zero. Check if the images' count of train dataset is more \
than batch_size in config.py") than batch_size in config.py")
@ -93,7 +86,7 @@ if __name__ == '__main__':
total_epochs=epoch_size, total_epochs=epoch_size,
steps_per_epoch=step_size)) steps_per_epoch=step_size))
if args_opt.pretrain_ckpt == "" or args_opt.freeze_layer == "none": if args_opt.pretrain_ckpt == "" or args_opt.freeze_layer != "backbone":
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \ opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
config.weight_decay, config.loss_scale) config.weight_decay, config.loss_scale)