forked from mindspore-Ecosystem/mindspore
add fasttext to model_zoo
This commit is contained in:
parent
5671b177f4
commit
edc48b48b8
|
@ -46,6 +46,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
- [BERT[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md)
|
||||
- [TinyBERT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/tinybert/README.md)
|
||||
- [GNMT V2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2/README.md)
|
||||
- [FastText](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext/README.md)
|
||||
- [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
|
||||
- [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md)
|
||||
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
|
||||
|
|
|
@ -0,0 +1,267 @@
|
|||
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [FastText](#fasttext)
|
||||
- [Model Structure](#model-structure)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Dataset Preparation](#dataset-preparation)
|
||||
- [Configuration File](#configuration-file)
|
||||
- [Training Process](#training-process)
|
||||
- [Inference Process](#inference-process)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Training Performance](#training-performance)
|
||||
- [Inference Performance](#inference-performance)
|
||||
- [Random Situation Description](#random-situation-description)
|
||||
- [Others](#others)
|
||||
- [ModelZoo HomePage](#modelzoo-homepage)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# [FastText](#contents)
|
||||
|
||||
FastText is a fast text classification algorithm, which is simple and efficient. It was proposed by Armand
|
||||
Joulin, Tomas Mikolov etc. in the artical "Bag of Tricks for Efficient Text Classification" in 2016. It is similar to
|
||||
CBOW in model architecture, where the middle word is replace by a label. FastText adopts ngram feature as addition feature
|
||||
to get some information about words. It speeds up training and testing while maintaining high percision, and widly used
|
||||
in various tasks of text classification.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1607.01759.pdf): "Bag of Tricks for Efficient Text Classification", 2016, A. Joulin, E. Grave, P. Bojanowski, and T. Mikolov
|
||||
|
||||
# [Model Structure](#contents)
|
||||
|
||||
The FastText model mainly consists of an input layer, hidden layer and output layer, where the input is a sequence of words (text or sentence).
|
||||
The output layer is probability that the words sequence belongs to different categories. The hidden layer is formed by average of multiple word vector.
|
||||
The feature is mapped to the hidden layer through linear transformation, and then mapped to the label from the hidden layer.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network
|
||||
architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
- AG's news topic classification dataset
|
||||
- DBPedia Ontology Classification Dataset
|
||||
- Yelp Review Polarity Dataset
|
||||
|
||||
# [Environment Requirements](#content)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Quick Start](#content)
|
||||
|
||||
After dataset preparation, you can start training and evaluation as follows:
|
||||
|
||||
```bash
|
||||
# run training example
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [TRAIN_DATASET]
|
||||
|
||||
# run distributed training example
|
||||
sh run_distribute_train.sh [TRAIN_DATASET] [RANK_TABLE_PATH]
|
||||
|
||||
# run evaluation example
|
||||
sh run_eval.sh [EVAL_DATASET_PATH] [DATASET_NAME] [MODEL_CKPT]
|
||||
```
|
||||
|
||||
# [Script Description](#content)
|
||||
|
||||
The FastText network script and code result are as follows:
|
||||
|
||||
```text
|
||||
├── fasttext
|
||||
├── README.md // Introduction of FastText model.
|
||||
├── src
|
||||
│ ├──config.py // Configuration instance definition.
|
||||
│ ├──create_dataset.py // Dataset preparation.
|
||||
│ ├──fasttext_model.py // FastText model architecture.
|
||||
│ ├──fasttext_train.py // Use FastText model architecture.
|
||||
│ ├──load_dataset.py // Dataset loader to feed into model.
|
||||
│ ├──lr_scheduler.py // Learning rate scheduler.
|
||||
├── scripts
|
||||
│ ├──run_distributed_train.sh // shell script for distributed train on ascend.
|
||||
│ ├──run_eval.sh // shell script for standalone eval on ascend.
|
||||
│ ├──run_standalone_train.sh // shell script for standalone eval on ascend.
|
||||
├── eval.py // Infer API entry.
|
||||
├── requirements.txt // Requirements of third party package.
|
||||
├── train.py // Train API entry.
|
||||
```
|
||||
|
||||
## [Dataset Preparation](#content)
|
||||
|
||||
- Download the AG's News Topic Classification Dataset, DBPedia Ontology Classification Dataset and Yelp Review Polarity Dataset. Unzip datasets to any path you want.
|
||||
|
||||
- Run the following scripts to do data preprocess and convert the original data to mindrecord for training and evaluation.
|
||||
|
||||
``` bash
|
||||
cd scripts
|
||||
sh creat_dataset.sh [SOURCE_DATASET_PATH] [DATASET_NAME]
|
||||
```
|
||||
|
||||
## [Configuration File](#content)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py. All the datasets are using same parameter name, parameters value could be changed according the needs.
|
||||
|
||||
- Network Parameters
|
||||
|
||||
```text
|
||||
vocab_size # vocabulary size.
|
||||
buckets # bucket sequence length.
|
||||
batch_size # batch size of input dataset.
|
||||
embedding_dims # The size of each embedding vector.
|
||||
num_class # number of labels.
|
||||
epoch # total training epochs.
|
||||
lr # initial learning rate.
|
||||
min_lr # minimum learning rate.
|
||||
warmup_steps # warm up steps.
|
||||
poly_lr_scheduler_power # a value used to calculate decayed learning rate.
|
||||
pretrain_ckpt_dir # pretrain checkpoint direction.
|
||||
keep_ckpt_max # Max ckpt files number.
|
||||
```
|
||||
|
||||
## [Training Process](#content)
|
||||
|
||||
- Start task training on a single device and run the shell script
|
||||
|
||||
```bash
|
||||
cd ./scripts
|
||||
sh run_standalone_train.sh [DATASET_PATH]
|
||||
```
|
||||
|
||||
- Running scripts for distributed training of FastText. Task training on multiple device and run the following command in bash to be executed in `scripts/`:
|
||||
|
||||
``` bash
|
||||
cd ./scripts
|
||||
sh run_distributed_train.sh [DATASET_PATH] [RANK_TABLE_PATH]
|
||||
```
|
||||
|
||||
## [Inference Process](#content)
|
||||
|
||||
- Running scripts for evaluation of FastText. The commdan as below.
|
||||
|
||||
``` bash
|
||||
cd ./scripts
|
||||
sh run_eval.sh [DATASET_PATH] [DATASET_NAME] [MODEL_CKPT]
|
||||
```
|
||||
|
||||
Note: The `DATASET_PATH` is path to mindrecord. eg. /dataset_path/*.mindrecord
|
||||
|
||||
# [Model Description](#content)
|
||||
|
||||
## [Performance](#content)
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | -------------------------------------------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| uploaded Date | 12/21/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | AG's News Topic Classification Dataset |
|
||||
| Training Parameters | epoch=5, batch_size=128 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Speed | 112ms/step (8pcs) |
|
||||
| Total Time | 66s (8pcs) |
|
||||
| Loss | 0.00082 |
|
||||
| Params (M) | 22 |
|
||||
| Checkpoint for inference | 254M (.ckpt file) |
|
||||
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | -------------------------------------------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| uploaded Date | 11/21/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | DBPedia Ontology Classification Dataset |
|
||||
| Training Parameters | epoch=5, batch_size=128 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Speed | 60ms/step (8pcs) |
|
||||
| Total Time | 164s (8pcs) |
|
||||
| Loss | 2.6e-5 |
|
||||
| Params (M) | 106 |
|
||||
| Checkpoint for inference | 1.2G (.ckpt file) |
|
||||
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | -------------------------------------------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| uploaded Date | 11/21/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | Yelp Review Polarity Dataset |
|
||||
| Training Parameters | epoch=5, batch_size=128 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | Softmax Cross Entropy |
|
||||
| outputs | probability |
|
||||
| Speed | 74ms/step (8pcs) |
|
||||
| Total Time | 195s (8pcs) |
|
||||
| Loss | 7.7e-4 |
|
||||
| Params (M) | 103 |
|
||||
| Checkpoint for inference | 1.2G (.ckpt file) |
|
||||
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
|
||||
|
||||
### Inference Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 12/21/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | AG's News Topic Classification Dataset |
|
||||
| batch_size | 128 |
|
||||
| Total Time | 66s |
|
||||
| outputs | label index |
|
||||
| Accuracy | 92.53 |
|
||||
| Model for inference | 254M (.ckpt file) |
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 12/21/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | DBPedia Ontology Classification Dataset |
|
||||
| batch_size | 128 |
|
||||
| Total Time | 164s |
|
||||
| outputs | label index |
|
||||
| Accuracy | 98.6 |
|
||||
| Model for inference | 1.2G (.ckpt file) |
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 12/21/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0 |
|
||||
| Dataset | Yelp Review Polarity Dataset |
|
||||
| batch_size | 128 |
|
||||
| Total Time | 195s |
|
||||
| outputs | label index |
|
||||
| Accuracy | 95.7 |
|
||||
| Model for inference | 1.2G (.ckpt file) |
|
||||
|
||||
# [Random Situation Description](#content)
|
||||
|
||||
There only one random situation.
|
||||
|
||||
- Initialization of some model weights.
|
||||
|
||||
Some seeds have already been set in train.py to avoid the randomness of weight initialization.
|
||||
|
||||
# [Others](#others)
|
||||
|
||||
This model has been validated in the Ascend environment and is not validated on the CPU and GPU.
|
||||
|
||||
# [ModelZoo HomePage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText for Evaluation"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
from mindspore import context
|
||||
from src.fasttext_model import FastText
|
||||
parser = argparse.ArgumentParser(description='fasttext')
|
||||
parser.add_argument('--data_path', type=str, help='infer dataset path..')
|
||||
parser.add_argument('--data_name', type=str, required=True, default='ag',
|
||||
help='dataset name. eg. ag, dbpedia')
|
||||
parser.add_argument("--model_ckpt", type=str, required=True,
|
||||
help="existed checkpoint address.")
|
||||
args = parser.parse_args()
|
||||
if args.data_name == "ag":
|
||||
from src.config import config_ag as config
|
||||
target_label1 = ['0', '1', '2', '3']
|
||||
elif args.data_name == 'dbpedia':
|
||||
from src.config import config_db as config
|
||||
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
|
||||
elif args.data_name == 'yelp_p':
|
||||
from src.config import config_yelpp as config
|
||||
target_label1 = ['0', '1']
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend")
|
||||
|
||||
class FastTextInferCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of FastText network infer.
|
||||
|
||||
Args:
|
||||
network (nn.Cell): FastText model.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor], predicted_ids
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(FastTextInferCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
|
||||
self.log_softmax = nn.LogSoftmax(axis=1)
|
||||
|
||||
def construct(self, src_tokens, src_tokens_lengths):
|
||||
"""construct fasttext infer cell"""
|
||||
prediction = self.network(src_tokens, src_tokens_lengths)
|
||||
predicted_idx = self.log_softmax(prediction)
|
||||
predicted_idx, _ = self.argmax(predicted_idx)
|
||||
|
||||
return predicted_idx
|
||||
|
||||
def load_infer_dataset(batch_size, datafile):
|
||||
"""data loader for infer"""
|
||||
ds = de.MindDataset(datafile, columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
|
||||
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="label_idx")
|
||||
ds = ds.batch(batch_size=batch_size, drop_remainder=True)
|
||||
|
||||
return ds
|
||||
|
||||
def run_fasttext_infer():
|
||||
"""run infer with FastText"""
|
||||
dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path)
|
||||
fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
|
||||
|
||||
parameter_dict = load_checkpoint(args.model_ckpt)
|
||||
load_param_into_net(fasttext_model, parameter_dict=parameter_dict)
|
||||
|
||||
ft_infer = FastTextInferCell(fasttext_model)
|
||||
|
||||
model = Model(ft_infer)
|
||||
|
||||
predictions = []
|
||||
target_sens = []
|
||||
|
||||
for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
target_sens.append(batch['label_idx'])
|
||||
src_tokens = Tensor(batch['src_tokens'], mstype.int32)
|
||||
src_tokens_length = Tensor(batch['src_tokens_length'], mstype.int32)
|
||||
predicted_idx = model.predict(src_tokens, src_tokens_length)
|
||||
predictions.append(predicted_idx.asnumpy())
|
||||
|
||||
from sklearn.metrics import accuracy_score, classification_report
|
||||
target_sens = np.array(target_sens).flatten()
|
||||
predictions = np.array(predictions).flatten()
|
||||
acc = accuracy_score(target_sens, predictions)
|
||||
|
||||
result_report = classification_report(target_sens, predictions, target_names=target_label1)
|
||||
print("********Accuracy: ", acc)
|
||||
print(result_report)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_fasttext_infer()
|
|
@ -0,0 +1,3 @@
|
|||
spacy
|
||||
sklearn
|
||||
en_core_web_lg
|
|
@ -0,0 +1,83 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh create_dataset.sh SOURCE_DATASET_PATH DATASET_NAME"
|
||||
echo "for example: sh create_dataset.sh /home/workspace/ag_news_csv ag"
|
||||
echo "DATASET_NAME including ag, dbpedia, and yelp_p"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
ulimit -u unlimited
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
SOURCE_DATASET_PATH=$(get_real_path $1)
|
||||
DATASET_NAME=$2
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=5
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ $DATASET_NAME == 'ag' ];
|
||||
then
|
||||
echo "Begin to process ag news data"
|
||||
if [ -d "ag" ];
|
||||
then
|
||||
rm -rf ./ag
|
||||
fi
|
||||
mkdir ./ag
|
||||
cd ./ag || exit
|
||||
echo "start data preprocess for device $DEVICE_ID"
|
||||
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 4 --max_len 467 --bucket [64,128,467] --test_bucket [467]
|
||||
cd ..
|
||||
fi
|
||||
|
||||
if [ $DATASET_NAME == 'dbpedia' ];
|
||||
then
|
||||
echo "Begin to process dbpedia data"
|
||||
if [ -d "dbpedia" ];
|
||||
then
|
||||
rm -rf ./dbpedia
|
||||
fi
|
||||
mkdir ./dbpedia
|
||||
cd ./dbpedia || exit
|
||||
echo "start data preprocess for device $DEVICE_ID"
|
||||
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 14 --max_len 3013 --bucket [128,512,3013] --test_bucket [1120]
|
||||
cd ..
|
||||
fi
|
||||
|
||||
if [ $DATASET_NAME == 'yelp_p' ];
|
||||
then
|
||||
echo "Begin to process ag news data"
|
||||
if [ -d "yelp_p" ];
|
||||
then
|
||||
rm -rf ./yelp_p
|
||||
fi
|
||||
mkdir ./yelp_p
|
||||
cd ./yelp_p || exit
|
||||
echo "start data preprocess for device $DEVICE_ID"
|
||||
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 2 --max_len 2955 --bucket [64,128,256,512,2955] --test_bucket [2955]
|
||||
cd ..
|
||||
fi
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_distributed_train.sh DATASET_PATH RANK_TABLE_PATH"
|
||||
echo "for example: sh run_distributed_train.sh /home/workspace/ag /home/workspace/rank_table_file.json"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET=$(get_real_path $1)
|
||||
echo $DATASET
|
||||
DATANAME=$(basename $DATASET)
|
||||
RANK_TABLE_PATH=$(get_real_path $2)
|
||||
echo $DATANAME
|
||||
if [ ! -d $DATASET ]
|
||||
then
|
||||
echo "Error: DATA_PATH=$DATASET is not a file"
|
||||
exit 1
|
||||
fi
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
export RANK_TABLE_FILE=$RANK_TABLE_PATH
|
||||
|
||||
|
||||
echo $RANK_TABLE_FILE
|
||||
export RANK_SIZE=8
|
||||
export DEVICE_NUM=8
|
||||
|
||||
|
||||
for((i=0;i<=7;i++));
|
||||
do
|
||||
rm -rf ${current_exec_path}/device$i
|
||||
mkdir ${current_exec_path}/device$i
|
||||
cd ${current_exec_path}/device$i
|
||||
cp ../../*.py ./
|
||||
cp -r ../../src ./
|
||||
cp -r ../*.sh ./
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
python ../../train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
|
||||
cd ${current_exec_path}
|
||||
done
|
||||
cd ${current_exec_path}
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_eval.sh DATASET_PATH DATASET_NAME MODEL_CKPT"
|
||||
echo "for example: sh run_eval.sh /home/workspace/ag/test*.mindrecord ag device0/ckpt0/fasttext-5-118.ckpt"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET=$(get_real_path $1)
|
||||
echo $DATASET
|
||||
DATANAME=$2
|
||||
MODEL_CKPT=$(get_real_path $3)
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=5
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cp -r ../scripts/*.sh ./eval
|
||||
cd ./eval || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python ../../eval.py --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "sh run_standalone_train.sh DATASET_PATH"
|
||||
echo "for example: sh run_standalone_train.sh /home/workspace/ag"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET=$(get_real_path $1)
|
||||
echo $DATASET
|
||||
DATANAME=$(basename $DATASET)
|
||||
echo $DATANAME
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp -r ../src ./train
|
||||
cp -r ../scripts/*.sh ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" :===========================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config_yelpp = ed({
|
||||
'vocab_size': 6414979,
|
||||
'buckets': [64, 128, 256, 512, 2955],
|
||||
'batch_size': 128,
|
||||
'embedding_dims': 16,
|
||||
'num_class': 2,
|
||||
'epoch': 5,
|
||||
'lr': 0.02,
|
||||
'min_lr': 1e-6,
|
||||
'decay_steps': 549,
|
||||
'warmup_steps': 400000,
|
||||
'poly_lr_scheduler_power': 0.5,
|
||||
'epoch_count': 1,
|
||||
'pretrain_ckpt_dir': None,
|
||||
'save_ckpt_steps': 549,
|
||||
'keep_ckpt_max': 10,
|
||||
})
|
||||
|
||||
config_db = ed({
|
||||
'vocab_size': 6596536,
|
||||
'buckets': [128, 512, 3013],
|
||||
'batch_size': 128,
|
||||
'embedding_dims': 16,
|
||||
'num_class': 14,
|
||||
'epoch': 5,
|
||||
'lr': 0.05,
|
||||
'min_lr': 1e-6,
|
||||
'decay_steps': 549,
|
||||
'warmup_steps': 400000,
|
||||
'poly_lr_scheduler_power': 0.5,
|
||||
'epoch_count': 1,
|
||||
'pretrain_ckpt_dir': None,
|
||||
'save_ckpt_steps': 548,
|
||||
'keep_ckpt_max': 10,
|
||||
})
|
||||
|
||||
config_ag = ed({
|
||||
'vocab_size': 1383812,
|
||||
'buckets': [64, 128, 467],
|
||||
'batch_size': 128,
|
||||
'embedding_dims': 16,
|
||||
'num_class': 4,
|
||||
'epoch': 5,
|
||||
'lr': 0.05,
|
||||
'min_lr': 1e-6,
|
||||
'decay_steps': 115,
|
||||
'warmup_steps': 400000,
|
||||
'poly_lr_scheduler_power': 0.5,
|
||||
'epoch_count': 1,
|
||||
'pretrain_ckpt_dir': None,
|
||||
'save_ckpt_steps': 116,
|
||||
'keep_ckpt_max': 10,
|
||||
})
|
|
@ -0,0 +1,316 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText data preprocess"""
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import argparse
|
||||
import pprint
|
||||
import ast
|
||||
import html
|
||||
import numpy as np
|
||||
import spacy
|
||||
from sklearn.feature_extraction import FeatureHasher
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
class FastTextDataPreProcess():
|
||||
"""FastText data preprocess"""
|
||||
def __init__(self, train_path,
|
||||
test_file,
|
||||
max_length,
|
||||
class_num,
|
||||
ngram,
|
||||
train_feature_dict,
|
||||
buckets,
|
||||
test_feature_dict,
|
||||
test_bucket,
|
||||
is_hashed,
|
||||
feature_size):
|
||||
self.train_path = train_path
|
||||
self.test_path = test_file
|
||||
self.max_length = max_length
|
||||
self.class_num = class_num
|
||||
self.train_feature_dict = train_feature_dict
|
||||
self.test_feature_dict = test_feature_dict
|
||||
self.test_bucket = test_bucket
|
||||
self.is_hashed = is_hashed
|
||||
self.feature_size = feature_size
|
||||
self.buckets = buckets
|
||||
self.ngram = ngram
|
||||
self.text_greater = '>'
|
||||
self.text_less = '<'
|
||||
self.ngram2idx = dict()
|
||||
self.idx2gram = dict()
|
||||
self.non_str = '\\'
|
||||
self.end_string = ['.', '?', '!']
|
||||
self.ngram2idx['PAD'] = 0
|
||||
self.idx2gram[0] = 'PAD'
|
||||
self.ngram2idx['UNK'] = 1
|
||||
self.idx2gram[1] = 'UNK'
|
||||
self.str_html = re.compile(r'<[^>]+>')
|
||||
|
||||
def load(self):
|
||||
"""data preprocess loader"""
|
||||
train_dataset_list = []
|
||||
test_dataset_list = []
|
||||
spacy_nlp = spacy.load('en_core_web_lg', disable=['parser', 'tagger', 'ner'])
|
||||
spacy_nlp.add_pipe(spacy_nlp.create_pipe('sentencizer'))
|
||||
|
||||
with open(self.train_path, 'r', newline='', encoding='utf-8') as src_file:
|
||||
reader = csv.reader(src_file, delimiter=",", quotechar='"')
|
||||
for _, _pair_sen in enumerate(reader):
|
||||
label_idx = int(_pair_sen[0]) - 1
|
||||
if len(_pair_sen) == 3:
|
||||
src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
|
||||
src_text2=_pair_sen[2],
|
||||
spacy_nlp=spacy_nlp,
|
||||
train_mode=True)
|
||||
src_tokens_length = len(src_tokens)
|
||||
elif len(_pair_sen) == 2:
|
||||
src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
|
||||
src_text2=None,
|
||||
spacy_nlp=spacy_nlp,
|
||||
train_mode=True)
|
||||
src_tokens_length = len(src_tokens)
|
||||
elif len(_pair_sen) == 4:
|
||||
if _pair_sen[2]:
|
||||
sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2]
|
||||
else:
|
||||
sen_o_t = _pair_sen[1]
|
||||
src_tokens = self.input_preprocess(src_text1=sen_o_t,
|
||||
src_text2=_pair_sen[3],
|
||||
spacy_nlp=spacy_nlp,
|
||||
train_mode=True)
|
||||
src_tokens_length = len(src_tokens)
|
||||
|
||||
train_dataset_list.append([src_tokens, src_tokens_length, label_idx])
|
||||
|
||||
print("Begin to process test data...")
|
||||
with open(self.test_path, 'r', newline='', encoding='utf-8') as test_file:
|
||||
reader2 = csv.reader(test_file, delimiter=",", quotechar='"')
|
||||
for _, _test_sen in enumerate(reader2):
|
||||
label_idx = int(_test_sen[0]) - 1
|
||||
if len(_test_sen) == 3:
|
||||
src_tokens = self.input_preprocess(src_text1=_test_sen[1],
|
||||
src_text2=_test_sen[2],
|
||||
spacy_nlp=spacy_nlp,
|
||||
train_mode=False)
|
||||
src_tokens_length = len(src_tokens)
|
||||
elif len(_test_sen) == 2:
|
||||
src_tokens = self.input_preprocess(src_text1=_test_sen[1],
|
||||
src_text2=None,
|
||||
spacy_nlp=spacy_nlp,
|
||||
train_mode=False)
|
||||
src_tokens_length = len(src_tokens)
|
||||
elif len(_test_sen) == 4:
|
||||
if _test_sen[2]:
|
||||
sen_o_t = _test_sen[1] + ' ' + _test_sen[2]
|
||||
else:
|
||||
sen_o_t = _test_sen[1]
|
||||
src_tokens = self.input_preprocess(src_text1=sen_o_t,
|
||||
src_text2=_test_sen[3],
|
||||
spacy_nlp=spacy_nlp,
|
||||
train_mode=False)
|
||||
src_tokens_length = len(src_tokens)
|
||||
|
||||
test_dataset_list.append([src_tokens, src_tokens_length, label_idx])
|
||||
|
||||
if self.is_hashed:
|
||||
print("Begin to Hashing Trick......")
|
||||
features_num = self.feature_size
|
||||
fh = FeatureHasher(n_features=features_num, alternate_sign=False)
|
||||
print("FeatureHasher features..", features_num)
|
||||
self.hash_trick(fh, train_dataset_list)
|
||||
self.hash_trick(fh, test_dataset_list)
|
||||
print("Hashing Done....")
|
||||
|
||||
# pad train dataset
|
||||
train_dataset_list_length = len(train_dataset_list)
|
||||
test_dataset_list_length = len(test_dataset_list)
|
||||
for l in range(train_dataset_list_length):
|
||||
bucket_length = self._get_bucket_length(train_dataset_list[l][0], self.buckets)
|
||||
while len(train_dataset_list[l][0]) < bucket_length:
|
||||
train_dataset_list[l][0].append(self.ngram2idx['PAD'])
|
||||
train_dataset_list[l][1] = len(train_dataset_list[l][0])
|
||||
# pad test dataset
|
||||
for j in range(test_dataset_list_length):
|
||||
test_bucket_length = self._get_bucket_length(test_dataset_list[j][0], self.test_bucket)
|
||||
while len(test_dataset_list[j][0]) < test_bucket_length:
|
||||
test_dataset_list[j][0].append(self.ngram2idx['PAD'])
|
||||
test_dataset_list[j][1] = len(test_dataset_list[j][0])
|
||||
|
||||
train_example_data = []
|
||||
test_example_data = []
|
||||
for idx in range(train_dataset_list_length):
|
||||
train_example_data.append({
|
||||
"src_tokens": train_dataset_list[idx][0],
|
||||
"src_tokens_length": train_dataset_list[idx][1],
|
||||
"label_idx": train_dataset_list[idx][2],
|
||||
})
|
||||
for key in self.train_feature_dict:
|
||||
if key == train_example_data[idx]['src_tokens_length']:
|
||||
self.train_feature_dict[key].append(train_example_data[idx])
|
||||
for h in range(test_dataset_list_length):
|
||||
test_example_data.append({
|
||||
"src_tokens": test_dataset_list[h][0],
|
||||
"src_tokens_length": test_dataset_list[h][1],
|
||||
"label_idx": test_dataset_list[h][2],
|
||||
})
|
||||
for key in self.test_feature_dict:
|
||||
if key == test_example_data[h]['src_tokens_length']:
|
||||
self.test_feature_dict[key].append(test_example_data[h])
|
||||
print("train vocab size is ", len(self.ngram2idx))
|
||||
|
||||
return self.train_feature_dict, self.test_feature_dict
|
||||
|
||||
def input_preprocess(self, src_text1, src_text2, spacy_nlp, train_mode):
|
||||
"""data preprocess func"""
|
||||
src_text1 = src_text1.strip()
|
||||
if src_text1 and src_text1[-1] not in self.end_string:
|
||||
src_text1 = src_text1 + '.'
|
||||
|
||||
if src_text2:
|
||||
src_text2 = src_text2.strip()
|
||||
sent_describe = src_text1 + ' ' + src_text2
|
||||
else:
|
||||
sent_describe = src_text1
|
||||
if self.non_str in sent_describe:
|
||||
sent_describe = sent_describe.replace(self.non_str, ' ')
|
||||
|
||||
sent_describe = html.unescape(sent_describe)
|
||||
|
||||
if self.text_less in sent_describe and self.text_greater in sent_describe:
|
||||
sent_describe = self.str_html.sub('', sent_describe)
|
||||
|
||||
|
||||
doc = spacy_nlp(sent_describe)
|
||||
bows_token = [token.text for token in doc]
|
||||
|
||||
try:
|
||||
tagged_sent_desc = '<p> ' + ' </s> '.join([s.text for s in doc.sents]) + ' </p>'
|
||||
except ValueError:
|
||||
tagged_sent_desc = '<p> ' + sent_describe + ' </p>'
|
||||
doc = spacy_nlp(tagged_sent_desc)
|
||||
ngrams = self.generate_gram([token.text for token in doc], num=self.ngram)
|
||||
|
||||
bo_ngrams = bows_token + ngrams
|
||||
|
||||
if train_mode is True:
|
||||
for ngms in bo_ngrams:
|
||||
idx = self.ngram2idx.get(ngms)
|
||||
if idx is None:
|
||||
idx = len(self.ngram2idx)
|
||||
self.ngram2idx[ngms] = idx
|
||||
self.idx2gram[idx] = ngms
|
||||
|
||||
processed_out = [self.ngram2idx[ng] if ng in self.ngram2idx else self.ngram2idx['UNK'] for ng in bo_ngrams]
|
||||
|
||||
return processed_out
|
||||
|
||||
def _get_bucket_length(self, x, bts):
|
||||
x_len = len(x)
|
||||
for index in range(1, len(bts)):
|
||||
if bts[index-1] < x_len <= bts[index]:
|
||||
return bts[index]
|
||||
return bts[0]
|
||||
|
||||
def generate_gram(self, words, num=2):
|
||||
|
||||
return [' '.join(words[i: i + num]) for i in range(len(words) - num + 1)]
|
||||
|
||||
def count2dict(self, lst):
|
||||
count_dict = dict()
|
||||
for m in lst:
|
||||
if str(m) in count_dict:
|
||||
count_dict[str(m)] += 1
|
||||
else:
|
||||
count_dict[str(m)] = 1
|
||||
return count_dict
|
||||
|
||||
def hash_trick(self, hashing, input_data):
|
||||
trans = hashing.transform((self.count2dict(e[0]) for e in input_data))
|
||||
for htr, e in zip(trans, input_data):
|
||||
sparse2bow = list()
|
||||
for idc, d in zip(htr.indices, htr.data):
|
||||
for _ in range(int(d)):
|
||||
sparse2bow.append(idc + 1)
|
||||
e[0] = sparse2bow
|
||||
|
||||
|
||||
def write_to_mindrecord(data, path, shared_num=1):
|
||||
"""generate mindrecord"""
|
||||
if not os.path.isabs(path):
|
||||
path = os.path.abspath(path)
|
||||
|
||||
writer = FileWriter(path, shared_num)
|
||||
data_schema = {
|
||||
"src_tokens": {"type": "int32", "shape": [-1]},
|
||||
"src_tokens_length": {"type": "int32", "shape": [-1]},
|
||||
"label_idx": {"type": "int32", "shape": [-1]}
|
||||
}
|
||||
writer.add_schema(data_schema, "fasttext")
|
||||
for item in data:
|
||||
item['src_tokens'] = np.array(item['src_tokens'], dtype=np.int32)
|
||||
item['src_tokens_length'] = np.array(item['src_tokens_length'], dtype=np.int32)
|
||||
item['label_idx'] = np.array(item['label_idx'], dtype=np.int32)
|
||||
writer.write_raw_data([item])
|
||||
writer.commit()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--train_file', type=str, required=True, help='train dataset file path')
|
||||
parser.add_argument('--test_file', type=str, required=True, help='test dataset file path')
|
||||
parser.add_argument('--class_num', type=int, required=True, help='Dataset classe number')
|
||||
parser.add_argument('--ngram', type=int, default=2, required=False)
|
||||
parser.add_argument('--max_len', type=int, required=False, help='max length sentence in dataset')
|
||||
parser.add_argument('--bucket', type=ast.literal_eval, default=[64, 128, 467], help='bucket sequence length.')
|
||||
parser.add_argument('--test_bucket', type=ast.literal_eval, default=[467], help='bucket sequence length.')
|
||||
parser.add_argument('--is_hashed', type=bool, default=False, help='add hash trick for dataset')
|
||||
parser.add_argument('--feature_size', type=int, default=10000000, help='hash feature size')
|
||||
|
||||
args = parser.parse_args()
|
||||
pprint.PrettyPrinter().pprint(args.__dict__)
|
||||
train_feature_dicts = {}
|
||||
for i in args.bucket:
|
||||
train_feature_dicts[i] = []
|
||||
test_feature_dicts = {}
|
||||
for i in args.test_bucket:
|
||||
test_feature_dicts[i] = []
|
||||
|
||||
g_d = FastTextDataPreProcess(train_path=args.train_file,
|
||||
test_file=args.test_file,
|
||||
max_length=args.max_len,
|
||||
ngram=args.ngram,
|
||||
class_num=args.class_num,
|
||||
train_feature_dict=train_feature_dicts,
|
||||
buckets=args.bucket,
|
||||
test_feature_dict=test_feature_dicts,
|
||||
test_bucket=args.test_bucket,
|
||||
is_hashed=args.is_hashed,
|
||||
feature_size=args.feature_size)
|
||||
train_data_example, test_data_example = g_d.load()
|
||||
print("Data preprocess done")
|
||||
print("Writing train data to MindRecord file......")
|
||||
|
||||
for i in args.bucket:
|
||||
write_to_mindrecord(train_data_example[i], './train_dataset_bs_' + str(i) + '.mindrecord', 1)
|
||||
|
||||
print("Writing test data to MindRecord file.....")
|
||||
for k in args.test_bucket:
|
||||
|
||||
write_to_mindrecord(test_data_example[k], './test_dataset_bs_' + str(k) + '.mindrecord', 1)
|
||||
|
||||
print("All done.....")
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText model."""
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import XavierUniform
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
class FastText(nn.Cell):
|
||||
"""
|
||||
FastText model
|
||||
Args:
|
||||
|
||||
vocab_size: vocabulary size
|
||||
embedding_dims: The size of each embedding vector
|
||||
num_class: number of labels
|
||||
"""
|
||||
def __init__(self, vocab_size, embedding_dims, num_class):
|
||||
super(FastText, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.embeding_dims = embedding_dims
|
||||
self.num_class = num_class
|
||||
self.embeding_func = nn.Embedding(vocab_size=self.vocab_size,
|
||||
embedding_size=self.embeding_dims,
|
||||
padding_idx=0, embedding_table='Zeros')
|
||||
self.fc = nn.Dense(self.embeding_dims, out_channels=self.num_class,
|
||||
weight_init=XavierUniform(1)).to_float(mstype.float16)
|
||||
self.reducesum = P.ReduceSum()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.squeeze = P.Squeeze(axis=1)
|
||||
self.cast = P.Cast()
|
||||
self.tile = P.Tile()
|
||||
self.realdiv = P.RealDiv()
|
||||
self.fill = P.Fill()
|
||||
self.log_softmax = nn.LogSoftmax(axis=1)
|
||||
def construct(self, src_tokens, src_token_length):
|
||||
"""
|
||||
construct network
|
||||
Args:
|
||||
|
||||
src_tokens: source sentences
|
||||
src_token_length: source sentences length
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor], network outputs
|
||||
"""
|
||||
src_tokens = self.embeding_func(src_tokens)
|
||||
embeding = self.reducesum(src_tokens, 1)
|
||||
|
||||
length_tiled = self.tile(src_token_length, (1, self.embeding_dims))
|
||||
|
||||
embeding = self.realdiv(embeding, length_tiled)
|
||||
|
||||
embeding = self.cast(embeding, mstype.float16)
|
||||
classifer = self.fc(embeding)
|
||||
classifer = self.cast(classifer, mstype.float32)
|
||||
|
||||
return classifer
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText for train"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import nn
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore import context
|
||||
from src.fasttext_model import FastText
|
||||
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type not in (0, 1):
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
class FastTextNetWithLoss(nn.Cell):
|
||||
"""
|
||||
Provide FastText training loss
|
||||
|
||||
Args:
|
||||
vocab_size: vocabulary size
|
||||
embedding_dims: The size of each embedding vector
|
||||
num_class: number of labels
|
||||
"""
|
||||
def __init__(self, vocab_size, embedding_dims, num_class):
|
||||
super(FastTextNetWithLoss, self).__init__()
|
||||
self.fasttext = FastText(vocab_size, embedding_dims, num_class)
|
||||
self.loss_func = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
self.squeeze = P.Squeeze(axis=1)
|
||||
self.print = P.Print()
|
||||
|
||||
def construct(self, src_tokens, src_tokens_lengths, label_idx):
|
||||
"""
|
||||
FastText network with loss.
|
||||
"""
|
||||
predict_score = self.fasttext(src_tokens, src_tokens_lengths)
|
||||
label_idx = self.squeeze(label_idx)
|
||||
predict_score = self.loss_func(predict_score, label_idx)
|
||||
|
||||
return predict_score
|
||||
|
||||
|
||||
class FastTextTrainOneStepCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of fasttext network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default: 1.0.
|
||||
"""
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(FastTextTrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode not in ParallelMode.MODE_LIST:
|
||||
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("gradients_mean")
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def set_sens(self, value):
|
||||
self.sens = value
|
||||
|
||||
def construct(self,
|
||||
src_token_text,
|
||||
src_tokens_text_length,
|
||||
label_idx_tag):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(src_token_text,
|
||||
src_tokens_text_length,
|
||||
label_idx_tag)
|
||||
grads = self.grad(self.network, weights)(src_token_text,
|
||||
src_tokens_text_length,
|
||||
label_idx_tag,
|
||||
self.cast(F.tuple_to_array((self.sens,)),
|
||||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
succ = self.optimizer(grads)
|
||||
return F.depend(loss, succ)
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText data loader"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as deC
|
||||
|
||||
def load_dataset(dataset_path,
|
||||
batch_size,
|
||||
epoch_count=1,
|
||||
rank_size=1,
|
||||
rank_id=0,
|
||||
bucket=None,
|
||||
shuffle=True):
|
||||
"""dataset loader"""
|
||||
def batch_per_bucket(bucket_length, input_file):
|
||||
input_file = input_file +'/train_dataset_bs_' + str(bucket_length) + '.mindrecord'
|
||||
if not input_file:
|
||||
raise FileNotFoundError("input file parameter must not be empty.")
|
||||
|
||||
ds = de.MindDataset(input_file,
|
||||
columns_list=['src_tokens', 'src_tokens_length', 'label_idx'],
|
||||
shuffle=shuffle,
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id,
|
||||
num_parallel_workers=8)
|
||||
ori_dataset_size = ds.get_dataset_size()
|
||||
print(f"Dataset size: {ori_dataset_size}")
|
||||
repeat_count = epoch_count
|
||||
type_cast_op = deC.TypeCast(mstype.int32)
|
||||
ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
|
||||
ds = ds.map(operations=type_cast_op, input_columns="label_idx")
|
||||
|
||||
ds = ds.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
|
||||
output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
|
||||
ds = ds.batch(batch_size, drop_remainder=False)
|
||||
ds = ds.repeat(repeat_count)
|
||||
return ds
|
||||
for i, _ in enumerate(bucket):
|
||||
bucket_len = bucket[i]
|
||||
ds_per = batch_per_bucket(bucket_len, dataset_path)
|
||||
if i == 0:
|
||||
ds = ds_per
|
||||
else:
|
||||
ds = ds + ds_per
|
||||
ds = ds.shuffle(ds.get_dataset_size())
|
||||
ds.channel_name = 'fasttext'
|
||||
|
||||
return ds
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning rate utilities."""
|
||||
from math import ceil
|
||||
import numpy as np
|
||||
|
||||
def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0):
|
||||
"""
|
||||
Implements of polynomial decay learning rate scheduler which cycles by default.
|
||||
|
||||
Args:
|
||||
lr (float): Initial learning rate.
|
||||
warmup_steps (int): Warmup steps.
|
||||
decay_steps (int): Decay steps.
|
||||
total_update_num (int): Total update steps.
|
||||
min_lr (float): Min learning.
|
||||
power (float): Power factor.
|
||||
|
||||
Returns:
|
||||
np.ndarray, learning rate of each step.
|
||||
"""
|
||||
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
|
||||
|
||||
if decay_steps <= 0:
|
||||
raise ValueError("`decay_steps` must larger than 1.")
|
||||
|
||||
_start_step = 0
|
||||
if 0 < warmup_steps < total_update_num:
|
||||
warmup_end_lr = lr
|
||||
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
|
||||
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
|
||||
_start_step = warmup_steps
|
||||
|
||||
decay_steps = decay_steps
|
||||
for step in range(_start_step, total_update_num):
|
||||
_step = step - _start_step
|
||||
ratio = ceil(_step / decay_steps)
|
||||
ratio = 1 if ratio < 1 else ratio
|
||||
_decay_steps = decay_steps * ratio
|
||||
lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
|
||||
|
||||
return lrs
|
|
@ -0,0 +1,202 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FastText for train"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import Callback, TimeMonitor
|
||||
from mindspore.communication import management as MultiAscend
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.load_dataset import load_dataset
|
||||
from src.lr_schedule import polynomial_decay_scheduler
|
||||
from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data_path', type=str, required=True, help='FastText input data file path.')
|
||||
parser.add_argument('--data_name', type=str, required=True, default='ag', help='dataset name. eg. ag, dbpedia')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.data_name == "ag":
|
||||
from src.config import config_ag as config
|
||||
elif args.data_name == 'dbpedia':
|
||||
from src.config import config_db as config
|
||||
elif args.data_name == 'yelp_p':
|
||||
from src.config import config_yelpp as config
|
||||
|
||||
def get_ms_timestamp():
|
||||
t = time.time()
|
||||
return int(round(t * 1000))
|
||||
set_seed(5)
|
||||
time_stamp_init = False
|
||||
time_stamp_first = 0
|
||||
rank_id = os.getenv('DEVICE_ID')
|
||||
context.set_context(
|
||||
mode=context.GRAPH_MODE,
|
||||
save_graphs=False,
|
||||
device_target="Ascend")
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
def __init__(self, per_print_times=1, rank_ids=0):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.rank_id = rank_ids
|
||||
global time_stamp_init, time_stamp_first
|
||||
if not time_stamp_init:
|
||||
time_stamp_first = get_ms_timestamp()
|
||||
time_stamp_init = True
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""Monitor the loss in training."""
|
||||
global time_stamp_first
|
||||
time_stamp_current = get_ms_timestamp()
|
||||
cb_params = run_context.original_args()
|
||||
print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)))
|
||||
with open("./loss_{}.log".format(self.rank_id), "a+") as f:
|
||||
f.write("time: {}, epoch: {}, step: {}, loss: {}".format(
|
||||
time_stamp_current - time_stamp_first,
|
||||
cb_params.cur_epoch_num,
|
||||
cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs.asnumpy())))
|
||||
f.write('\n')
|
||||
|
||||
|
||||
def _build_training_pipeline(pre_dataset):
|
||||
"""
|
||||
Build training pipeline
|
||||
|
||||
Args:
|
||||
pre_dataset: preprocessed dataset
|
||||
"""
|
||||
net_with_loss = FastTextNetWithLoss(config.vocab_size, config.embedding_dims, config.num_class)
|
||||
net_with_loss.init_parameters_data()
|
||||
if config.pretrain_ckpt_dir:
|
||||
parameter_dict = load_checkpoint(config.pretrain_ckpt_dir)
|
||||
load_param_into_net(net_with_loss, parameter_dict)
|
||||
if pre_dataset is None:
|
||||
raise ValueError("pre-process dataset must be provided")
|
||||
|
||||
#get learning rate
|
||||
update_steps = config.epoch * pre_dataset.get_dataset_size()
|
||||
decay_steps = pre_dataset.get_dataset_size()
|
||||
rank_size = os.getenv("RANK_SIZE")
|
||||
if isinstance(rank_size, int):
|
||||
raise ValueError("RANK_SIZE must be integer")
|
||||
if rank_size is not None and int(rank_size) > 1:
|
||||
base_lr = config.lr
|
||||
else:
|
||||
base_lr = config.lr / 10
|
||||
print("+++++++++++Total update steps ", update_steps)
|
||||
lr = Tensor(polynomial_decay_scheduler(lr=base_lr,
|
||||
min_lr=config.min_lr,
|
||||
decay_steps=decay_steps,
|
||||
total_update_num=update_steps,
|
||||
warmup_steps=config.warmup_steps,
|
||||
power=config.poly_lr_scheduler_power), dtype=mstype.float32)
|
||||
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.999)
|
||||
|
||||
net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer)
|
||||
net_with_grads.set_train(True)
|
||||
model = Model(net_with_grads)
|
||||
loss_monitor = LossCallBack(rank_ids=rank_id)
|
||||
dataset_size = pre_dataset.get_dataset_size()
|
||||
time_monitor = TimeMonitor(data_size=dataset_size)
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps,
|
||||
keep_checkpoint_max=config.keep_ckpt_max)
|
||||
callbacks = [time_monitor, loss_monitor]
|
||||
if rank_size is None or int(rank_size) == 1:
|
||||
ckpt_callback = ModelCheckpoint(prefix='fasttext',
|
||||
directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))),
|
||||
config=ckpt_config)
|
||||
callbacks.append(ckpt_callback)
|
||||
if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
|
||||
ckpt_callback = ModelCheckpoint(prefix='fasttext',
|
||||
directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))),
|
||||
config=ckpt_config)
|
||||
callbacks.append(ckpt_callback)
|
||||
print("Prepare to Training....")
|
||||
epoch_size = pre_dataset.get_repeat_count()
|
||||
print("Epoch size ", epoch_size)
|
||||
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
|
||||
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
|
||||
model.train(epoch=config.epoch, train_dataset=pre_dataset, callbacks=callbacks, dataset_sink_mode=False)
|
||||
|
||||
|
||||
def train_single(input_file_path):
|
||||
"""
|
||||
Train model on single device
|
||||
Args:
|
||||
input_file_path: preprocessed dataset path
|
||||
"""
|
||||
print("Staring training on single device.")
|
||||
preprocessed_data = load_dataset(dataset_path=input_file_path,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epoch_count,
|
||||
bucket=config.buckets)
|
||||
_build_training_pipeline(preprocessed_data)
|
||||
|
||||
|
||||
def set_parallel_env():
|
||||
context.reset_auto_parallel_context()
|
||||
MultiAscend.init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
gradients_mean=True)
|
||||
def train_paralle(input_file_path):
|
||||
"""
|
||||
Train model on multi device
|
||||
Args:
|
||||
input_file_path: preprocessed dataset path
|
||||
"""
|
||||
set_parallel_env()
|
||||
print("Starting traning on mutiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|")
|
||||
preprocessed_data = load_dataset(dataset_path=input_file_path,
|
||||
batch_size=config.batch_size,
|
||||
epoch_count=config.epoch_count,
|
||||
rank_size=MultiAscend.get_group_size(),
|
||||
rank_id=MultiAscend.get_rank(),
|
||||
bucket=config.buckets,
|
||||
shuffle=False)
|
||||
_build_training_pipeline(preprocessed_data)
|
||||
|
||||
if __name__ == "__main__":
|
||||
_rank_size = os.getenv("RANK_SIZE")
|
||||
if _rank_size is not None and int(_rank_size) > 1:
|
||||
train_paralle(args.data_path)
|
||||
else:
|
||||
train_single(args.data_path)
|
Loading…
Reference in New Issue