forked from mindspore-Ecosystem/mindspore
!12780 add distribute file for naml
From: @zhao_ting_v Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian
This commit is contained in:
commit
f192732e8d
|
@ -74,7 +74,11 @@ You can download the dataset and put the directory in structure as follows:
|
|||
You can start training using python or shell scripts. The usage of shell scripts as follows:
|
||||
|
||||
```shell
|
||||
# train standalone
|
||||
bash run_train.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH]
|
||||
# train distribute
|
||||
bash run_distribute_train.sh [PLATFORM] [DEVICE_NUM] [DATASET] [DATASET_PATH] [RANK_TABLE_FILE]
|
||||
# evaluation
|
||||
bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
|
@ -83,6 +87,7 @@ bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PAT
|
|||
- `DATASET` MIND dataset, support large, small and demo.
|
||||
- `DATASET_PATH` is the dataset path, the structure as [Dataset](#dataset).
|
||||
- `CHECKPOINT_PATH` is a pre-trained checkpoint path.
|
||||
- `RANK_TABLE_FILE` is HCCL configuration file when running on Ascend.
|
||||
|
||||
## [Model Export](#contents)
|
||||
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash run_distribute_train.sh [PLATFORM] [DEVICE_NUM] [DATASET] [DATASET_PATH] [RANK_TABLE_FILE]"
|
||||
echo "for example: bash run_distribute_train.sh Ascend 8 large /path/MINDlarge /path/hccl_8p.json"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
PLATFORM=$1
|
||||
export RANK_SIZE=$2
|
||||
DATASET=$3
|
||||
DATASET_PATH=$4
|
||||
export RANK_TABLE_FILE=$5
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
CHECKPOINT_PATH=${PROJECT_DIR}/checkpoint
|
||||
cd ${PROJECT_DIR}/.. || exit
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cp ./*.py ./LOG$i
|
||||
cp -r ./src ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python train.py \
|
||||
--platform=${PLATFORM} \
|
||||
--device_num=${RANK_SIZE} \
|
||||
--device_id=${DEVICE_ID} \
|
||||
--dataset=${DATASET} \
|
||||
--dataset_path=${DATASET_PATH} \
|
||||
--save_checkpoint_path=${CHECKPOINT_PATH} \
|
||||
--weight_decay=False \
|
||||
--sink_mode=True > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
|
@ -16,7 +16,7 @@
|
|||
import re
|
||||
import os
|
||||
import random
|
||||
from collections import namedtuple
|
||||
import math
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
|
@ -24,9 +24,6 @@ import mindspore.dataset as ds
|
|||
|
||||
ds.config.set_prefetch_size(8)
|
||||
|
||||
NEWS_ITEMS = ['category', 'subcategory', 'title', 'abstract']
|
||||
News = namedtuple('News', NEWS_ITEMS)
|
||||
|
||||
class MINDPreprocess:
|
||||
"""
|
||||
MIND dataset Preprocess class.
|
||||
|
@ -279,10 +276,43 @@ class EvalCandidateNews(EvalDatasetBase):
|
|||
nid, label = self.preprocess.impression_list[index]
|
||||
return uid, nid, label
|
||||
|
||||
class DistributedSampler():
|
||||
"""
|
||||
sampling the dataset.
|
||||
|
||||
def create_dataset(mindpreprocess, batch_size=64):
|
||||
Args:
|
||||
Returns:
|
||||
num_samples, number of samples.
|
||||
"""
|
||||
def __init__(self, preprocess: MINDPreprocess, rank, group_size, shuffle=True, seed=0):
|
||||
self.preprocess = preprocess
|
||||
self.rank = rank
|
||||
self.group_size = group_size
|
||||
self.dataset_length = preprocess.total_count
|
||||
self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
|
||||
self.total_size = self.num_samples * self.group_size
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
self.seed = (self.seed + 1) & 0xffffffff
|
||||
np.random.seed(self.seed)
|
||||
indices = np.random.permutation(self.dataset_length).tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset_length)))
|
||||
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
indices = indices[self.rank::self.group_size]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def create_dataset(mindpreprocess, batch_size=64, rank=0, group_size=1):
|
||||
"""Get generator dataset when training."""
|
||||
dataset = ds.GeneratorDataset(mindpreprocess, mindpreprocess.column_names, shuffle=True)
|
||||
sampler = DistributedSampler(mindpreprocess, rank, group_size, shuffle=True)
|
||||
dataset = ds.GeneratorDataset(mindpreprocess, mindpreprocess.column_names, sampler=sampler)
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||
return dataset
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import argparse
|
||||
import ast
|
||||
import os
|
||||
import math
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init, get_rank
|
||||
|
@ -52,7 +53,6 @@ def get_args(phase):
|
|||
parser.add_argument('--batch_size', type=int, default=64, help='size of each batch')
|
||||
# Training specifications
|
||||
if phase == "train":
|
||||
parser.add_argument('--train_dataset_path', type=str, default=None, help='training set directory')
|
||||
parser.add_argument('--epochs', type=int, default=None, help='number of epochs for training')
|
||||
parser.add_argument('--lr', type=float, default=None, help='learning rate')
|
||||
parser.add_argument('--beta1', type=float, default=0.9, help='ADAM beta1')
|
||||
|
@ -72,7 +72,6 @@ def get_args(phase):
|
|||
help="Save checkpoint path, default is checkpoint.")
|
||||
parser.add_argument('--dropout_ratio', type=float, default=0.2, help='ratio of dropout')
|
||||
if phase == "eval":
|
||||
parser.add_argument('--eval_dataset_path', type=str, default=None)
|
||||
parser.add_argument('--neg_sample', type=int, default=-1, \
|
||||
help='number of negative samples in negative sampling')
|
||||
if phase == "export":
|
||||
|
@ -100,7 +99,7 @@ def get_args(phase):
|
|||
args.n_sub_categories = cfg.n_sub_categories
|
||||
args.n_words = cfg.n_words
|
||||
if phase == "train":
|
||||
args.epochs = cfg.epochs if args.epochs is None else args.epochs
|
||||
args.epochs = cfg.epochs if args.epochs is None else args.epochs * math.ceil(args.device_num ** 0.5)
|
||||
args.lr = cfg.lr if args.lr is None else args.lr
|
||||
args.print_times = cfg.print_times if args.print_times is None else args.print_times
|
||||
args.embedding_file = cfg.embedding_file.format(args.dataset_path)
|
||||
|
|
|
@ -34,7 +34,8 @@ if __name__ == '__main__':
|
|||
if args.checkpoint_path is not None:
|
||||
load_checkpoint(args.pretrain_checkpoint, net_with_loss)
|
||||
mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
|
||||
dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size)
|
||||
dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size, rank=args.rank,
|
||||
group_size=args.device_num)
|
||||
args.dataset_size = dataset.get_dataset_size()
|
||||
args.print_times = min(args.dataset_size, args.print_times)
|
||||
if args.weight_decay:
|
||||
|
|
Loading…
Reference in New Issue