!12780 add distribute file for naml

From: @zhao_ting_v
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @wuxuejian
mindspore-ci-bot 2021-03-02 16:02:28 +08:00 committed by Gitee
commit f192732e8d
5 changed files with 97 additions and 10 deletions

View File

@@ -74,7 +74,11 @@ You can download the dataset and put the directory in structure as follows:
 You can start training using python or shell scripts. The usage of shell scripts as follows:
 ```shell
+# train standalone
 bash run_train.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH]
+# train distribute
+bash run_distribute_train.sh [PLATFORM] [DEVICE_NUM] [DATASET] [DATASET_PATH] [RANK_TABLE_FILE]
+# evaluation
 bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]
 ```
@@ -83,6 +87,7 @@ bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]
 - `DATASET` MIND dataset, support large, small and demo.
 - `DATASET_PATH` is the dataset path, the structure as [Dataset](#dataset).
 - `CHECKPOINT_PATH` is a pre-trained checkpoint path.
+- `RANK_TABLE_FILE` is HCCL configuration file when running on Ascend.
 ## [Model Export](#contents)

View File

@@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh [PLATFORM] [DEVICE_NUM] [DATASET] [DATASET_PATH] [RANK_TABLE_FILE]"
echo "for example: bash run_distribute_train.sh Ascend 8 large /path/MINDlarge /path/hccl_8p.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
PLATFORM=$1
export RANK_SIZE=$2
DATASET=$3
DATASET_PATH=$4
export RANK_TABLE_FILE=$5
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CHECKPOINT_PATH=${PROJECT_DIR}/checkpoint
cd ${PROJECT_DIR}/.. || exit
for((i=0;i<RANK_SIZE;i++))
do
    # Give each device its own working directory and log file.
    rm -rf LOG$i
    mkdir ./LOG$i
    cp ./*.py ./LOG$i
    cp -r ./src ./LOG$i
    cd ./LOG$i || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    # Launch one training process per device in the background; output goes to LOG$i/log.txt.
    python train.py \
        --platform=${PLATFORM} \
        --device_num=${RANK_SIZE} \
        --device_id=${DEVICE_ID} \
        --dataset=${DATASET} \
        --dataset_path=${DATASET_PATH} \
        --save_checkpoint_path=${CHECKPOINT_PATH} \
        --weight_decay=False \
        --sink_mode=True > log.txt 2>&1 &
    cd ..
done

View File

@@ -16,7 +16,7 @@
 import re
 import os
 import random
-from collections import namedtuple
+import math
 import pickle

 import numpy as np
@@ -24,9 +24,6 @@ import mindspore.dataset as ds

 ds.config.set_prefetch_size(8)

-NEWS_ITEMS = ['category', 'subcategory', 'title', 'abstract']
-News = namedtuple('News', NEWS_ITEMS)
-
 class MINDPreprocess:
     """
     MIND dataset Preprocess class.
@@ -279,10 +276,43 @@ class EvalCandidateNews(EvalDatasetBase):
         nid, label = self.preprocess.impression_list[index]
         return uid, nid, label

+class DistributedSampler():
+    """
+    Sampler that shards the dataset across devices for distributed training.
+
+    Args:
+        preprocess (MINDPreprocess): preprocessed MIND dataset.
+        rank (int): rank id of the current device.
+        group_size (int): number of devices taking part in training.
+        shuffle (bool): whether to reshuffle the indices every epoch.
+        seed (int): base random seed.
+
+    Returns:
+        num_samples, number of samples.
+    """
+    def __init__(self, preprocess: MINDPreprocess, rank, group_size, shuffle=True, seed=0):
+        self.preprocess = preprocess
+        self.rank = rank
+        self.group_size = group_size
+        self.dataset_length = preprocess.total_count
+        self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
+        self.total_size = self.num_samples * self.group_size
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self):
+        if self.shuffle:
+            self.seed = (self.seed + 1) & 0xffffffff
+            np.random.seed(self.seed)
+            indices = np.random.permutation(self.dataset_length).tolist()
+        else:
+            indices = list(range(self.dataset_length))
+        # Pad the index list up to a multiple of group_size, then take every
+        # group_size-th index starting from this rank.
+        indices += indices[:(self.total_size - len(indices))]
+        indices = indices[self.rank::self.group_size]
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
-def create_dataset(mindpreprocess, batch_size=64):
+def create_dataset(mindpreprocess, batch_size=64, rank=0, group_size=1):
     """Get generator dataset when training."""
-    dataset = ds.GeneratorDataset(mindpreprocess, mindpreprocess.column_names, shuffle=True)
+    sampler = DistributedSampler(mindpreprocess, rank, group_size, shuffle=True)
+    dataset = ds.GeneratorDataset(mindpreprocess, mindpreprocess.column_names, sampler=sampler)
     dataset = dataset.batch(batch_size, drop_remainder=True)
     return dataset
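
To make the sharding arithmetic above concrete, here is a small self-contained sketch (plain Python, no MindSpore needed; `shard_indices` is an illustrative helper, not part of this PR): indices are padded to a multiple of the device count, then each rank takes every `group_size`-th index starting at its own rank.

```python
import math

def shard_indices(dataset_length, rank, group_size):
    """Mimic DistributedSampler.__iter__ without shuffling."""
    num_samples = int(math.ceil(dataset_length / group_size))
    total_size = num_samples * group_size
    indices = list(range(dataset_length))
    indices += indices[:(total_size - len(indices))]  # wrap-around padding
    return indices[rank::group_size]

# With 10 samples on 4 devices, every rank sees ceil(10 / 4) = 3 samples:
for r in range(4):
    print(r, shard_indices(10, r, 4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

The wrap-around padding guarantees every device draws the same number of samples, which keeps the data-parallel ranks in step.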

View File

@@ -16,6 +16,7 @@
 import argparse
 import ast
 import os
+import math
 from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore.communication.management import init, get_rank
@@ -52,7 +53,6 @@ def get_args(phase):
     parser.add_argument('--batch_size', type=int, default=64, help='size of each batch')
     # Training specifications
     if phase == "train":
-        parser.add_argument('--train_dataset_path', type=str, default=None, help='training set directory')
         parser.add_argument('--epochs', type=int, default=None, help='number of epochs for training')
         parser.add_argument('--lr', type=float, default=None, help='learning rate')
         parser.add_argument('--beta1', type=float, default=0.9, help='ADAM beta1')
@@ -72,7 +72,6 @@ def get_args(phase):
                         help="Save checkpoint path, default is checkpoint.")
     parser.add_argument('--dropout_ratio', type=float, default=0.2, help='ratio of dropout')
     if phase == "eval":
-        parser.add_argument('--eval_dataset_path', type=str, default=None)
         parser.add_argument('--neg_sample', type=int, default=-1, \
                             help='number of negative samples in negative sampling')
     if phase == "export":
@@ -100,7 +99,7 @@ def get_args(phase):
     args.n_sub_categories = cfg.n_sub_categories
     args.n_words = cfg.n_words
     if phase == "train":
-        args.epochs = cfg.epochs if args.epochs is None else args.epochs
+        args.epochs = cfg.epochs if args.epochs is None else args.epochs * math.ceil(args.device_num ** 0.5)
         args.lr = cfg.lr if args.lr is None else args.lr
         args.print_times = cfg.print_times if args.print_times is None else args.print_times
     args.embedding_file = cfg.embedding_file.format(args.dataset_path)
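
As a quick sanity check of the epoch adjustment above: when `--epochs` is passed explicitly, it is multiplied by `ceil(sqrt(device_num))`, so a single device keeps the requested value while eight devices triple it. A minimal sketch (the helper name is mine, not from the PR):

```python
import math

def scaled_epochs(requested_epochs, device_num):
    """Reproduce the adjustment applied in get_args when --epochs is given."""
    return requested_epochs * math.ceil(device_num ** 0.5)

print(scaled_epochs(10, 1))  # 10, since ceil(sqrt(1)) = 1
print(scaled_epochs(10, 2))  # 20, since ceil(sqrt(2)) = 2
print(scaled_epochs(10, 8))  # 30, since ceil(sqrt(8)) = 3
```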

View File

@@ -34,7 +34,8 @@ if __name__ == '__main__':
     if args.checkpoint_path is not None:
         load_checkpoint(args.pretrain_checkpoint, net_with_loss)
     mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
-    dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size)
+    dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size, rank=args.rank,
+                             group_size=args.device_num)
     args.dataset_size = dataset.get_dataset_size()
     args.print_times = min(args.dataset_size, args.print_times)
     if args.weight_decay:
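
train.py assumes `args.rank` and `args.device_num` are already populated before `create_dataset` is called; that setup code is outside this diff. Purely as an illustration of the usual MindSpore data-parallel initialization (a hedged sketch with assumed names and placement, not the actual code from this PR):

```python
# Hypothetical sketch of how rank/device_num are typically obtained for
# MindSpore data-parallel training; the real train.py may differ.
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

def setup_distributed(args):
    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
    if args.device_num > 1:
        init()                        # set up HCCL (Ascend) / NCCL (GPU) communication
        args.rank = get_rank()        # rank id of this process
        args.device_num = get_group_size()
        context.set_auto_parallel_context(device_num=args.device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        args.rank = 0
    return args
```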