forked from mindspore-Ecosystem/mindspore
!12780 add distribute file for naml
From: @zhao_ting_v
Reviewed-by: @wuxuejian, @c_34
Signed-off-by: @wuxuejian
commit f192732e8d
@@ -74,7 +74,11 @@ You can download the dataset and put the directory in structure as follows:
 You can start training using python or shell scripts. The usage of shell scripts as follows:
 
 ```shell
+# train standalone
 bash run_train.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH]
+# train distribute
+bash run_distribute_train.sh [PLATFORM] [DEVICE_NUM] [DATASET] [DATASET_PATH] [RANK_TABLE_FILE]
+# evaluation
 bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]
 ```
 
@@ -83,6 +87,7 @@ bash run_eval.sh [PLATFORM] [DEVICE_ID] [DATASET] [DATASET_PATH] [CHECKPOINT_PATH]
 - `DATASET` MIND dataset, support large, small and demo.
 - `DATASET_PATH` is the dataset path, the structure as [Dataset](#dataset).
 - `CHECKPOINT_PATH` is a pre-trained checkpoint path.
+- `RANK_TABLE_FILE` is the HCCL configuration file when running on Ascend.
 
 ## [Model Export](#contents)
 
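A concrete end-to-end invocation of the new distribute entry point appears in `run_distribute_train.sh` below: `bash run_distribute_train.sh Ascend 8 large /path/MINDlarge /path/hccl_8p.json`.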
@@ -0,0 +1,52 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_distribute_train.sh [PLATFORM] [DEVICE_NUM] [DATASET] [DATASET_PATH] [RANK_TABLE_FILE]"
+echo "for example: bash run_distribute_train.sh Ascend 8 large /path/MINDlarge /path/hccl_8p.json"
+echo "It is better to use absolute path."
+echo "=============================================================================================================="
+
+PLATFORM=$1
+export RANK_SIZE=$2
+DATASET=$3
+DATASET_PATH=$4
+export RANK_TABLE_FILE=$5
+PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
+CHECKPOINT_PATH=${PROJECT_DIR}/checkpoint
+cd ${PROJECT_DIR}/.. || exit
+
+for((i=0;i<RANK_SIZE;i++))
+do
+    rm -rf LOG$i
+    mkdir ./LOG$i
+    cp ./*.py ./LOG$i
+    cp -r ./src ./LOG$i
+    cd ./LOG$i || exit
+    export RANK_ID=$i
+    export DEVICE_ID=$i
+    python train.py \
+    --platform=${PLATFORM} \
+    --device_num=${RANK_SIZE} \
+    --device_id=${DEVICE_ID} \
+    --dataset=${DATASET} \
+    --dataset_path=${DATASET_PATH} \
+    --save_checkpoint_path=${CHECKPOINT_PATH} \
+    --weight_decay=False \
+    --sink_mode=True > log.txt 2>&1 &
+    cd ..
+done
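One design note on the script above: it launches `RANK_SIZE` background trainers, one per device, each inside its own `LOG$i` working directory with a private copy of the sources, so per-rank `log.txt` files and checkpoints cannot collide. The exported `RANK_ID`, `DEVICE_ID`, and `RANK_TABLE_FILE` are what MindSpore's HCCL initialization uses to bind each process to its device.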
@@ -16,7 +16,7 @@
 import re
 import os
 import random
-from collections import namedtuple
+import math
 import pickle
 
 import numpy as np
@@ -24,9 +24,6 @@ import mindspore.dataset as ds
 
 ds.config.set_prefetch_size(8)
 
-NEWS_ITEMS = ['category', 'subcategory', 'title', 'abstract']
-News = namedtuple('News', NEWS_ITEMS)
-
 class MINDPreprocess:
     """
     MIND dataset Preprocess class.
@@ -279,10 +276,43 @@ class EvalCandidateNews(EvalDatasetBase):
         nid, label = self.preprocess.impression_list[index]
         return uid, nid, label
 
 
+class DistributedSampler():
+    """
+    Sampler that shards the dataset across devices for distributed training.
+
+    Args:
+        preprocess (MINDPreprocess): preprocessed dataset, provides total_count.
+        rank (int): rank id of the current process.
+        group_size (int): number of processes in the group.
+        shuffle (bool): whether to reshuffle the indices each epoch.
+        seed (int): random seed for shuffling.
+
+    Returns:
+        num_samples, number of samples.
+    """
+    def __init__(self, preprocess: MINDPreprocess, rank, group_size, shuffle=True, seed=0):
+        self.preprocess = preprocess
+        self.rank = rank
+        self.group_size = group_size
+        self.dataset_length = preprocess.total_count
+        self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
+        self.total_size = self.num_samples * self.group_size
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self):
+        if self.shuffle:
+            self.seed = (self.seed + 1) & 0xffffffff
+            np.random.seed(self.seed)
+            indices = np.random.permutation(self.dataset_length).tolist()
+        else:
+            indices = list(range(self.dataset_length))
+
+        # pad with wrap-around so every rank gets num_samples indices, then stride by rank
+        indices += indices[:(self.total_size - len(indices))]
+        indices = indices[self.rank::self.group_size]
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+
-def create_dataset(mindpreprocess, batch_size=64):
+def create_dataset(mindpreprocess, batch_size=64, rank=0, group_size=1):
     """Get generator dataset when training."""
-    dataset = ds.GeneratorDataset(mindpreprocess, mindpreprocess.column_names, shuffle=True)
+    sampler = DistributedSampler(mindpreprocess, rank, group_size, shuffle=True)
+    dataset = ds.GeneratorDataset(mindpreprocess, mindpreprocess.column_names, sampler=sampler)
     dataset = dataset.batch(batch_size, drop_remainder=True)
     return dataset
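The index arithmetic in `DistributedSampler.__iter__` (wrap-around padding to `total_size`, then striding by `rank`) is easiest to check on a toy case. A minimal standalone sketch of just that math, with made-up sizes rather than code from this repo:

```python
import math

def shard_indices(dataset_length, rank, group_size):
    # Mirrors DistributedSampler with shuffle=False: every rank
    # receives exactly num_samples indices.
    num_samples = math.ceil(dataset_length / group_size)
    total_size = num_samples * group_size
    indices = list(range(dataset_length))
    indices += indices[:(total_size - len(indices))]  # wrap-around padding
    return indices[rank::group_size]

for rank in range(4):
    print(rank, shard_indices(10, rank, 4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

The padding means the highest ranks re-see a few early samples, but every rank yields the same number of indices, which keeps `dataset.batch(batch_size, drop_remainder=True)` aligned across devices.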
@@ -16,6 +16,7 @@
 import argparse
 import ast
 import os
+import math
 from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore.communication.management import init, get_rank
@@ -52,7 +53,6 @@ def get_args(phase):
     parser.add_argument('--batch_size', type=int, default=64, help='size of each batch')
     # Training specifications
     if phase == "train":
-        parser.add_argument('--train_dataset_path', type=str, default=None, help='training set directory')
         parser.add_argument('--epochs', type=int, default=None, help='number of epochs for training')
         parser.add_argument('--lr', type=float, default=None, help='learning rate')
         parser.add_argument('--beta1', type=float, default=0.9, help='ADAM beta1')
@@ -72,7 +72,6 @@ def get_args(phase):
                         help="Save checkpoint path, default is checkpoint.")
     parser.add_argument('--dropout_ratio', type=float, default=0.2, help='ratio of dropout')
     if phase == "eval":
-        parser.add_argument('--eval_dataset_path', type=str, default=None)
         parser.add_argument('--neg_sample', type=int, default=-1, \
                             help='number of negative samples in negative sampling')
     if phase == "export":
@@ -100,7 +99,7 @@ def get_args(phase):
     args.n_sub_categories = cfg.n_sub_categories
     args.n_words = cfg.n_words
     if phase == "train":
-        args.epochs = cfg.epochs if args.epochs is None else args.epochs
+        args.epochs = cfg.epochs if args.epochs is None else args.epochs * math.ceil(args.device_num ** 0.5)
         args.lr = cfg.lr if args.lr is None else args.lr
         args.print_times = cfg.print_times if args.print_times is None else args.print_times
         args.embedding_file = cfg.embedding_file.format(args.dataset_path)
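Note that the scaled branch only applies when `--epochs` is given on the command line (when `args.epochs is None` it falls back to `cfg.epochs` unscaled), and the multiplier is `ceil(sqrt(device_num))`. A quick check of the arithmetic:

```python
import math

def effective_epochs(epochs, device_num):
    # mirrors: args.epochs * math.ceil(args.device_num ** 0.5)
    return epochs * math.ceil(device_num ** 0.5)

print(effective_epochs(1, 1))  # 1 * ceil(1.00) = 1
print(effective_epochs(1, 8))  # 1 * ceil(2.83) = 3
print(effective_epochs(2, 8))  # 2 * ceil(2.83) = 6
```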
@@ -34,7 +34,8 @@ if __name__ == '__main__':
     if args.checkpoint_path is not None:
         load_checkpoint(args.pretrain_checkpoint, net_with_loss)
     mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
-    dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size)
+    dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size, rank=args.rank,
+                             group_size=args.device_num)
     args.dataset_size = dataset.get_dataset_size()
     args.print_times = min(args.dataset_size, args.print_times)
     if args.weight_decay:
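For readers tracing the full path: each process launched by `run_distribute_train.sh` initializes communication, learns its rank, and only then builds the sharded dataset. The snippet below is a condensed, hypothetical sketch of that wiring assembled from the imports in the utils hunk above (`init`, `get_rank`, `ParallelMode`), not the literal `train.py` code:

```python
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank

def setup_distributed(device_num, device_id):
    # One process per device; RANK_ID, DEVICE_ID and RANK_TABLE_FILE
    # come from the environment exported by the launch script.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    if device_num > 1:
        init()  # HCCL initialization, reads RANK_TABLE_FILE
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        return get_rank()
    return 0
```

With `args.rank` obtained this way, passing `rank=args.rank, group_size=args.device_num` into `create_dataset` hands data sharding to `DistributedSampler` instead of a global shuffle.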