!11490 Add LSTM Ascend distribute train

From: @ttudu
Reviewed-by: @c_34,@guoqi1024
Signed-off-by: @guoqi1024
mindspore-ci-bot 2021-01-22 11:24:37 +08:00 committed by Gitee
commit 44cd679a5f
5 changed files with 94 additions and 21 deletions

View File

@@ -22,11 +22,9 @@ import numpy as np
 from src.config import lstm_cfg, lstm_cfg_ascend
 from src.dataset import lstm_create_dataset, convert_to_mindrecord
-from src.lr_schedule import get_lr
 from src.lstm import SentimentNet
 from mindspore import Tensor, nn, Model, context
-from mindspore.nn import Accuracy
-from mindspore.train.callback import LossMonitor
+from mindspore.nn import Accuracy, Recall, F1
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

 if __name__ == '__main__':
@@ -79,20 +77,8 @@ if __name__ == '__main__':
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     ds_eval = lstm_create_dataset(args.preprocess_path, cfg.batch_size, training=False)
-    if cfg.dynamic_lr:
-        lr = Tensor(get_lr(global_step=cfg.global_step,
-                           lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max,
-                           warmup_epochs=cfg.warmup_epochs,
-                           total_epochs=cfg.num_epochs,
-                           steps_per_epoch=ds_eval.get_dataset_size(),
-                           lr_adjust_epoch=cfg.lr_adjust_epoch))
-    else:
-        lr = cfg.learning_rate
-    opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
-    loss_cb = LossMonitor()
-    model = Model(network, loss, opt, {'acc': Accuracy()})
+    model = Model(network, loss, metrics={'acc': Accuracy(), 'recall': Recall(), 'f1': F1()})

     print("============== Starting Testing ==============")
     param_dict = load_checkpoint(args.ckpt_path)
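With the optimizer and LossMonitor removed, the eval script only needs a Model built with the richer metric set. A minimal sketch of how the new metrics dict is consumed (network, loss, ds_eval and args come from the rest of the script, not from this hunk):

# Sketch only -- reuses names the eval script already defines (network, loss, ds_eval, args).
from mindspore import Model
from mindspore.nn import Accuracy, Recall, F1
from mindspore.train.serialization import load_checkpoint, load_param_into_net

model = Model(network, loss, metrics={'acc': Accuracy(), 'recall': Recall(), 'f1': F1()})
param_dict = load_checkpoint(args.ckpt_path)      # restore the trained weights
load_param_into_net(network, param_dict)
result = model.eval(ds_eval)                      # returns {'acc': ..., 'recall': ..., 'f1': ...}
print("============== {} ==============".format(result))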

View File

@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train_ascend.sh RANK_TABLE_FILE DEVICE_NUM ACLIMDB_DIR GLOVE_DIR"
echo "for example: bash run_distribute_train_ascend.sh /path/hccl.json 8 /path/aclimdb /path/glove"
echo "It is better to use absolute path."
echo "=============================================================================================================="
ROOT_PATH=`pwd`
export RANK_TABLE_FILE=$1
RANK_SIZE=$2
ACLIMDB_DIR=$3
GLOVE_DIR=$4
for((i=0;i<${RANK_SIZE};i++));
do
    rm ${ROOT_PATH}/device$i/ -rf
    mkdir ${ROOT_PATH}/device$i
    cd ${ROOT_PATH}/device$i || exit
    cp ../../*.py ./
    cp -r ../../src ./
    export RANK_ID=$i
    export DEVICE_ID=$i
    python train.py \
        --device_target="Ascend" \
        --aclimdb_path=$ACLIMDB_DIR \
        --glove_path=$GLOVE_DIR \
        --distribute=true \
        --device_num=$RANK_SIZE \
        --preprocess=true \
        --preprocess_path=./preprocess > log.txt 2>&1 &
done
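Design note: the launcher creates one working directory per rank (device0 … device7), copies the training sources into it, and starts each `python train.py` in the background, so the eight processes keep separate logs and outputs. RANK_TABLE_FILE is the standard HCCL rank-table variable on Ascend, and RANK_ID/DEVICE_ID bind each process to its NPU; progress for rank i ends up in device$i/log.txt.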

View File

@@ -33,7 +33,7 @@ lstm_cfg = edict({
     'keep_checkpoint_max': 10
 })

-# LSTM CONFIG IN ASCEND
+# LSTM CONFIG IN ASCEND for 1p training
 lstm_cfg_ascend = edict({
     'num_classes': 2,
     'momentum': 0.9,
@@ -53,3 +53,24 @@ lstm_cfg_ascend = edict({
     'warmup_epochs': 1,
     'global_step': 0
 })
+
+# LSTM CONFIG IN ASCEND for 8p training
+lstm_cfg_ascend_8p = edict({
+    'num_classes': 2,
+    'momentum': 0.9,
+    'num_epochs': 20,
+    'batch_size': 64,
+    'embed_size': 300,
+    'num_hiddens': 128,
+    'num_layers': 2,
+    'bidirectional': True,
+    'save_checkpoint_steps': 7800,
+    'keep_checkpoint_max': 10,
+    'dynamic_lr': True,
+    'lr_init': 0.05,
+    'lr_end': 0.01,
+    'lr_max': 0.3,
+    'lr_adjust_epoch': 20,
+    'warmup_epochs': 2,
+    'global_step': 0
+})
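The 8p config turns on dynamic_lr; the actual schedule is produced by get_lr in src/lr_schedule.py, which this diff does not show. As a rough illustration of how these fields could combine, here is a sketch assuming the common warmup-then-linear-decay pattern; the real implementation may differ:

# Illustrative sketch only -- not the real src/lr_schedule.py get_lr.
# Assumes linear warmup lr_init -> lr_max over warmup_epochs, then linear decay to lr_end
# over lr_adjust_epoch, and constant lr_end afterwards.
import numpy as np

def sketch_get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs,
                  total_epochs, steps_per_epoch, lr_adjust_epoch):
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    adjust_steps = steps_per_epoch * lr_adjust_epoch
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        elif i < adjust_steps:
            lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (adjust_steps - warmup_steps)
        else:
            lr = lr_end
        lr_each_step.append(lr)
    return np.array(lr_each_step, dtype=np.float32)[global_step:]

# e.g. lrs = sketch_get_lr(0, 0.05, 0.01, 0.3, 2, 20, steps_per_epoch, 20) with the 8p values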

View File

@@ -24,14 +24,15 @@ from mindspore.mindrecord import FileWriter
 from .imdb import ImdbParser


-def lstm_create_dataset(data_home, batch_size, repeat_num=1, training=True):
+def lstm_create_dataset(data_home, batch_size, repeat_num=1, training=True, device_num=1, rank=0):
     """Data operations."""
     ds.config.set_seed(1)
     data_dir = os.path.join(data_home, "aclImdb_train.mindrecord0")
     if not training:
         data_dir = os.path.join(data_home, "aclImdb_test.mindrecord0")

-    data_set = ds.MindDataset(data_dir, columns_list=["feature", "label"], num_parallel_workers=4)
+    data_set = ds.MindDataset(data_dir, columns_list=["feature", "label"], num_parallel_workers=4,
+                              num_shards=device_num, shard_id=rank)

     # apply map operations on images
     data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
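The new num_shards/shard_id arguments make MindDataset split the MindRecord file across ranks, so each device iterates over a disjoint 1/device_num slice of the data. A sketch of how the updated helper would be called in a distributed run (train.py passes args.device_num and get_rank(); get_group_size() is used here as an equivalent way to obtain the shard count):

# Sketch: per-rank sharded dataset creation under data parallel.
from mindspore.communication.management import init, get_rank, get_group_size
from src.dataset import lstm_create_dataset

init()                                  # set up HCCL (reads RANK_TABLE_FILE etc. on Ascend)
device_num = get_group_size()           # e.g. 8 on an 8p job
rank = get_rank()                       # 0..device_num-1, one per process
ds_train = lstm_create_dataset("./preprocess", batch_size=64,
                               device_num=device_num, rank=rank)
# each rank now reads roughly 1/device_num of the MindRecord samples per epoch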

View File

@@ -20,7 +20,7 @@ import os
 import numpy as np

-from src.config import lstm_cfg, lstm_cfg_ascend
+from src.config import lstm_cfg, lstm_cfg_ascend, lstm_cfg_ascend_8p
 from src.dataset import convert_to_mindrecord
 from src.dataset import lstm_create_dataset
 from src.lr_schedule import get_lr
@@ -29,6 +29,8 @@ from mindspore import Tensor, nn, Model, context
 from mindspore.nn import Accuracy
 from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor
 from mindspore.train.serialization import load_param_into_net, load_checkpoint
+from mindspore.communication.management import init, get_rank
+from mindspore.context import ParallelMode

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='MindSpore LSTM Example')
@@ -46,6 +48,9 @@ if __name__ == '__main__':
                         help='the pretrained checkpoint file path.')
     parser.add_argument('--device_target', type=str, default="Ascend", choices=['GPU', 'CPU', 'Ascend'],
                         help='the target device to run, support "GPU", "CPU". Default: "Ascend".')
+    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
+    parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
+                        help="Run distribute, default is false.")
     args = parser.parse_args()

     context.set_context(
@@ -53,8 +58,20 @@ if __name__ == '__main__':
         save_graphs=False,
         device_target=args.device_target)

+    rank = 0
+    device_num = 1
     if args.device_target == 'Ascend':
         cfg = lstm_cfg_ascend
+        if args.distribute == "true":
+            cfg = lstm_cfg_ascend_8p
+            init()
+            device_num = args.device_num
+            rank = get_rank()
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
+                                              device_num=device_num)
     else:
         cfg = lstm_cfg
@@ -82,7 +99,7 @@ if __name__ == '__main__':
     if args.pre_trained:
         load_param_into_net(network, load_checkpoint(args.pre_trained))

-    ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1)
+    ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1, device_num=device_num, rank=rank)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

     if cfg.dynamic_lr:
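Taken together, the train.py hunks follow the standard MindSpore data-parallel recipe. A condensed sketch of the new path (argument parsing, model construction and the rest of the script omitted; the names are the ones the diff itself introduces):

# Condensed sketch of the distributed branch added to train.py -- not a drop-in replacement.
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank

rank, device_num = 0, 1
if args.device_target == 'Ascend' and args.distribute == "true":
    cfg = lstm_cfg_ascend_8p                      # 8p hyper-parameters from src/config.py
    init()                                        # HCCL init, driven by RANK_TABLE_FILE / RANK_ID / DEVICE_ID
    device_num = args.device_num
    rank = get_rank()
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True, device_num=device_num)

# every rank builds the same network but trains on its own shard of the MindRecord data
ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1,
                               device_num=device_num, rank=rank)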