add fasttext to model_zoo

This commit is contained in:
zhaojichen 2020-12-21 16:28:46 +08:00
parent 5671b177f4
commit edc48b48b8
15 changed files with 1565 additions and 0 deletions


@@ -46,6 +46,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [BERT[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md)
- [TinyBERT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/tinybert/README.md)
- [GNMT V2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2/README.md)
- [FastText](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext/README.md)
- [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
- [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md)
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)


@@ -0,0 +1,267 @@
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)
<!-- TOC -->
- [FastText](#fasttext)
- [Model Structure](#model-structure)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Dataset Preparation](#dataset-preparation)
    - [Configuration File](#configuration-file)
    - [Training Process](#training-process)
    - [Inference Process](#inference-process)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Random Situation Description](#random-situation-description)
- [Others](#others)
- [ModelZoo HomePage](#modelzoo-homepage)
<!-- /TOC -->
# [FastText](#contents)
FastText is a fast and efficient text classification algorithm proposed by Armand
Joulin, Tomas Mikolov et al. in the article "Bag of Tricks for Efficient Text Classification" in 2016. Its model architecture is similar to
CBOW, with the middle word replaced by a label. FastText adopts n-gram features as additional features
to capture some information about word order. It speeds up training and testing while maintaining high precision, and is widely used
in various text classification tasks.
[Paper](https://arxiv.org/pdf/1607.01759.pdf): "Bag of Tricks for Efficient Text Classification", 2016, A. Joulin, E. Grave, P. Bojanowski, and T. Mikolov
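The n-gram features mentioned above are simply joined runs of adjacent tokens. As a small illustration (this mirrors the `generate_gram` helper used in the data preprocessing script):

```python
def generate_gram(words, num=2):
    """Generate n-grams (default bigrams) from a list of tokens."""
    return [' '.join(words[i: i + num]) for i in range(len(words) - num + 1)]

print(generate_gram(['fast', 'text', 'is', 'fast']))
# ['fast text', 'text is', 'is fast']
```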
# [Model Structure](#contents)
The FastText model mainly consists of an input layer, a hidden layer and an output layer, where the input is a sequence of words (a text or sentence).
The output layer gives the probability that the word sequence belongs to each category. The hidden layer is the average of the word vectors.
The features are mapped to the hidden layer through a linear transformation, and then mapped from the hidden layer to the labels.
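This computation can be illustrated with a minimal NumPy sketch (toy sizes chosen for illustration; the shipped MindSpore implementation is in `src/fasttext_model.py`):

```python
import numpy as np

vocab_size, embedding_dims, num_class = 1000, 16, 4   # toy sizes, not the real config
embedding = np.random.randn(vocab_size, embedding_dims).astype(np.float32)
fc_weight = np.random.randn(embedding_dims, num_class).astype(np.float32)

src_tokens = np.array([3, 17, 256, 42])               # a tokenized sentence (word/n-gram ids)
hidden = embedding[src_tokens].mean(axis=0)           # hidden layer: average of the word vectors
logits = hidden @ fc_weight                           # linear mapping to the label space
probs = np.exp(logits) / np.exp(logits).sum()         # class probabilities
print(probs)
```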
# [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network
architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
- AG's news topic classification dataset
- DBPedia Ontology Classification Dataset
- Yelp Review Polarity Dataset
# [Environment Requirements](#content)
- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#content)
After dataset preparation, you can start training and evaluation as follows:
```bash
# run training example
cd ./scripts
sh run_standalone_train.sh [TRAIN_DATASET]
# run distributed training example
sh run_distributed_train.sh [TRAIN_DATASET] [RANK_TABLE_PATH]
# run evaluation example
sh run_eval.sh [EVAL_DATASET_PATH] [DATASET_NAME] [MODEL_CKPT]
```
# [Script Description](#content)
The FastText network scripts and code structure are as follows:
```text
├── fasttext
├── README.md // Introduction of FastText model.
├── src
│ ├──config.py // Configuration instance definition.
│ ├──create_dataset.py // Dataset preparation.
│ ├──fasttext_model.py // FastText model architecture.
│ ├──fasttext_train.py // FastText training wrapper (loss and one-step training cell).
│ ├──load_dataset.py // Dataset loader to feed into model.
│ ├──lr_scheduler.py // Learning rate scheduler.
├── scripts
│ ├──run_distributed_train.sh // shell script for distributed train on ascend.
│ ├──run_eval.sh // shell script for standalone eval on ascend.
│ ├──run_standalone_train.sh // shell script for standalone train on ascend.
├── eval.py // Infer API entry.
├── requirements.txt // Requirements of third party package.
├── train.py // Train API entry.
```
## [Dataset Preparation](#content)
- Download the AG's News Topic Classification Dataset, DBPedia Ontology Classification Dataset and Yelp Review Polarity Dataset. Unzip datasets to any path you want.
- Run the following scripts to preprocess the data and convert the original data to MindRecord files for training and evaluation.
``` bash
cd scripts
sh create_dataset.sh [SOURCE_DATASET_PATH] [DATASET_NAME]
```
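The conversion writes one MindRecord file per bucket (for example `train_dataset_bs_64.mindrecord` under the dataset's output directory). As a quick sanity check, a generated file can be read back with `MindDataset`; a minimal sketch (the file path shown is an example):

```python
import mindspore.dataset.engine as de

# Read back one of the generated MindRecord files (path is illustrative).
ds = de.MindDataset("./ag/train_dataset_bs_64.mindrecord",
                    columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
print("number of samples:", ds.get_dataset_size())
```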
## [Configuration File](#content)
Parameters for both training and evaluation can be set in config.py. All datasets use the same parameter names; parameter values can be changed according to your needs.
- Network Parameters
```text
vocab_size # vocabulary size.
buckets # bucket sequence length.
batch_size # batch size of input dataset.
embedding_dims # The size of each embedding vector.
num_class # number of labels.
epoch # total training epochs.
lr # initial learning rate.
min_lr # minimum learning rate.
warmup_steps # warm up steps.
poly_lr_scheduler_power # a value used to calculate decayed learning rate.
pretrain_ckpt_dir # pretrained checkpoint directory.
keep_ckpt_max # maximum number of checkpoint files to keep.
```
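The configuration objects are plain `EasyDict` instances defined in `src/config.py`, so individual values can be inspected or adjusted in Python before training; a minimal sketch:

```python
from src.config import config_ag as config   # AG News settings; config_db / config_yelpp also exist

print(config.vocab_size, config.buckets, config.batch_size)
config.epoch = 10   # values can be overridden in place if needed
```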
## [Training Process](#content)
- To start training on a single device, run the shell script:
```bash
cd ./scripts
sh run_standalone_train.sh [DATASET_PATH]
```
- To run distributed training of FastText on multiple devices, execute the following command in `scripts/`:
``` bash
cd ./scripts
sh run_distributed_train.sh [DATASET_PATH] [RANK_TABLE_PATH]
```
## [Inference Process](#content)
- To run evaluation of FastText, use the command below:
``` bash
cd ./scripts
sh run_eval.sh [DATASET_PATH] [DATASET_NAME] [MODEL_CKPT]
```
Note: `DATASET_PATH` is the path to the MindRecord files, e.g. /dataset_path/*.mindrecord
# [Model Description](#content)
## [Performance](#content)
### Training Performance
| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910 |
| uploaded Date | 12/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | AG's News Topic Classification Dataset |
| Training Parameters | epoch=5, batch_size=128 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Speed | 112ms/step (8pcs) |
| Total Time | 66s (8pcs) |
| Loss | 0.00082 |
| Params (M) | 22 |
| Checkpoint for inference | 254M (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |

| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910 |
| uploaded Date | 11/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | DBPedia Ontology Classification Dataset |
| Training Parameters | epoch=5, batch_size=128 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Speed | 60ms/step (8pcs) |
| Total Time | 164s (8pcs) |
| Loss | 2.6e-5 |
| Params (M) | 106 |
| Checkpoint for inference | 1.2G (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |

| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910 |
| uploaded Date | 11/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Yelp Review Polarity Dataset |
| Training Parameters | epoch=5, batch_size=128 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Speed | 74ms/step (8pcs) |
| Total Time | 195s (8pcs) |
| Loss | 7.7e-4 |
| Params (M) | 103 |
| Checkpoint for inference | 1.2G (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource | Ascend 910 |
| Uploaded Date | 12/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | AG's News Topic Classification Dataset |
| batch_size | 128 |
| Total Time | 66s |
| outputs | label index |
| Accuracy | 92.53 |
| Model for inference | 254M (.ckpt file) |

| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource | Ascend 910 |
| Uploaded Date | 12/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | DBPedia Ontology Classification Dataset |
| batch_size | 128 |
| Total Time | 164s |
| outputs | label index |
| Accuracy | 98.6 |
| Model for inference | 1.2G (.ckpt file) |

| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource | Ascend 910 |
| Uploaded Date | 12/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Yelp Review Polarity Dataset |
| batch_size | 128 |
| Total Time | 195s |
| outputs | label index |
| Accuracy | 95.7 |
| Model for inference | 1.2G (.ckpt file) |
# [Random Situation Description](#content)
There is only one source of randomness:
- Initialization of some model weights.
Some seeds have already been set in train.py to avoid the randomness of weight initialization.
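For reference, the seed is fixed at the top level of `train.py`:

```python
from mindspore.common import set_seed

set_seed(5)
```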
# [Others](#others)
This model has been validated in the Ascend environment and has not been validated on CPU or GPU.
# [ModelZoo HomePage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FastText for Evaluation"""
import argparse
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.ops.operations as P
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC
from mindspore import context
from src.fasttext_model import FastText
parser = argparse.ArgumentParser(description='fasttext')
parser.add_argument('--data_path', type=str, help='infer dataset path.')
parser.add_argument('--data_name', type=str, required=True, default='ag',
help='dataset name. eg. ag, dbpedia')
parser.add_argument("--model_ckpt", type=str, required=True,
help="existed checkpoint address.")
args = parser.parse_args()
if args.data_name == "ag":
from src.config import config_ag as config
target_label1 = ['0', '1', '2', '3']
elif args.data_name == 'dbpedia':
from src.config import config_db as config
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
elif args.data_name == 'yelp_p':
    from src.config import config_yelpp as config
    target_label1 = ['0', '1']
else:
    raise ValueError("Unsupported dataset name: {}".format(args.data_name))
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend")
class FastTextInferCell(nn.Cell):
"""
Encapsulation class of FastText network infer.
Args:
network (nn.Cell): FastText model.
Returns:
        Tensor, predicted label indexes.
"""
def __init__(self, network):
super(FastTextInferCell, self).__init__(auto_prefix=False)
self.network = network
self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
self.log_softmax = nn.LogSoftmax(axis=1)
def construct(self, src_tokens, src_tokens_lengths):
"""construct fasttext infer cell"""
prediction = self.network(src_tokens, src_tokens_lengths)
predicted_idx = self.log_softmax(prediction)
predicted_idx, _ = self.argmax(predicted_idx)
return predicted_idx
def load_infer_dataset(batch_size, datafile):
"""data loader for infer"""
ds = de.MindDataset(datafile, columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
ds = ds.map(operations=type_cast_op, input_columns="label_idx")
ds = ds.batch(batch_size=batch_size, drop_remainder=True)
return ds
def run_fasttext_infer():
"""run infer with FastText"""
dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path)
fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
parameter_dict = load_checkpoint(args.model_ckpt)
load_param_into_net(fasttext_model, parameter_dict=parameter_dict)
ft_infer = FastTextInferCell(fasttext_model)
model = Model(ft_infer)
predictions = []
target_sens = []
for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
target_sens.append(batch['label_idx'])
src_tokens = Tensor(batch['src_tokens'], mstype.int32)
src_tokens_length = Tensor(batch['src_tokens_length'], mstype.int32)
predicted_idx = model.predict(src_tokens, src_tokens_length)
predictions.append(predicted_idx.asnumpy())
from sklearn.metrics import accuracy_score, classification_report
target_sens = np.array(target_sens).flatten()
predictions = np.array(predictions).flatten()
acc = accuracy_score(target_sens, predictions)
result_report = classification_report(target_sens, predictions, target_names=target_label1)
print("********Accuracy: ", acc)
print(result_report)
if __name__ == '__main__':
run_fasttext_infer()


@@ -0,0 +1,3 @@
spacy
sklearn
en_core_web_lg


@@ -0,0 +1,83 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh create_dataset.sh SOURCE_DATASET_PATH DATASET_NAME"
echo "for example: sh create_dataset.sh /home/workspace/ag_news_csv ag"
echo "DATASET_NAME including ag, dbpedia, and yelp_p"
echo "It is better to use absolute path."
echo "=============================================================================================================="
ulimit -u unlimited
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
SOURCE_DATASET_PATH=$(get_real_path $1)
DATASET_NAME=$2
export DEVICE_NUM=1
export DEVICE_ID=5
export RANK_ID=0
export RANK_SIZE=1
if [ $DATASET_NAME == 'ag' ];
then
echo "Begin to process ag news data"
if [ -d "ag" ];
then
rm -rf ./ag
fi
mkdir ./ag
cd ./ag || exit
echo "start data preprocess for device $DEVICE_ID"
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 4 --max_len 467 --bucket [64,128,467] --test_bucket [467]
cd ..
fi
if [ $DATASET_NAME == 'dbpedia' ];
then
echo "Begin to process dbpedia data"
if [ -d "dbpedia" ];
then
rm -rf ./dbpedia
fi
mkdir ./dbpedia
cd ./dbpedia || exit
echo "start data preprocess for device $DEVICE_ID"
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 14 --max_len 3013 --bucket [128,512,3013] --test_bucket [1120]
cd ..
fi
if [ $DATASET_NAME == 'yelp_p' ];
then
echo "Begin to process ag news data"
if [ -d "yelp_p" ];
then
rm -rf ./yelp_p
fi
mkdir ./yelp_p
cd ./yelp_p || exit
echo "start data preprocess for device $DEVICE_ID"
python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 2 --max_len 2955 --bucket [64,128,256,512,2955] --test_bucket [2955]
cd ..
fi


@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distributed_train.sh DATASET_PATH RANK_TABLE_PATH"
echo "for example: sh run_distributed_train.sh /home/workspace/ag /home/workspace/rank_table_file.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
echo $DATASET
DATANAME=$(basename $DATASET)
RANK_TABLE_PATH=$(get_real_path $2)
echo $DATANAME
if [ ! -d $DATASET ]
then
echo "Error: DATA_PATH=$DATASET is not a file"
exit 1
fi
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_TABLE_FILE=$RANK_TABLE_PATH
echo $RANK_TABLE_FILE
export RANK_SIZE=8
export DEVICE_NUM=8
for((i=0;i<=7;i++));
do
rm -rf ${current_exec_path}/device$i
mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i
cp ../../*.py ./
cp -r ../../src ./
cp -r ../*.sh ./
export RANK_ID=$i
export DEVICE_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
python ../../train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
cd ${current_exec_path}
done
cd ${current_exec_path}


@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_eval.sh DATASET_PATH DATASET_NAME MODEL_CKPT"
echo "for example: sh run_eval.sh /home/workspace/ag/test*.mindrecord ag device0/ckpt0/fasttext-5-118.ckpt"
echo "It is better to use absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
echo $DATASET
DATANAME=$2
MODEL_CKPT=$(get_real_path $3)
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=5
export RANK_ID=0
export RANK_SIZE=1
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cp -r ../scripts/*.sh ./eval
cd ./eval || exit
echo "start training for device $DEVICE_ID"
env > env.log
python ../../eval.py --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
cd ..


@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_standalone_train.sh DATASET_PATH"
echo "for example: sh run_standalone_train.sh /home/workspace/ag"
echo "It is better to use absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
echo $DATASET
DATANAME=$(basename $DATASET)
echo $DATANAME
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cp -r ../scripts/*.sh ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
cd ..


@@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" :===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config_yelpp = ed({
'vocab_size': 6414979,
'buckets': [64, 128, 256, 512, 2955],
'batch_size': 128,
'embedding_dims': 16,
'num_class': 2,
'epoch': 5,
'lr': 0.02,
'min_lr': 1e-6,
'decay_steps': 549,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 549,
'keep_ckpt_max': 10,
})
config_db = ed({
'vocab_size': 6596536,
'buckets': [128, 512, 3013],
'batch_size': 128,
'embedding_dims': 16,
'num_class': 14,
'epoch': 5,
'lr': 0.05,
'min_lr': 1e-6,
'decay_steps': 549,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 548,
'keep_ckpt_max': 10,
})
config_ag = ed({
'vocab_size': 1383812,
'buckets': [64, 128, 467],
'batch_size': 128,
'embedding_dims': 16,
'num_class': 4,
'epoch': 5,
'lr': 0.05,
'min_lr': 1e-6,
'decay_steps': 115,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 116,
'keep_ckpt_max': 10,
})


@@ -0,0 +1,316 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FastText data preprocess"""
import csv
import os
import re
import argparse
import pprint
import ast
import html
import numpy as np
import spacy
from sklearn.feature_extraction import FeatureHasher
from mindspore.mindrecord import FileWriter
class FastTextDataPreProcess():
"""FastText data preprocess"""
def __init__(self, train_path,
test_file,
max_length,
class_num,
ngram,
train_feature_dict,
buckets,
test_feature_dict,
test_bucket,
is_hashed,
feature_size):
self.train_path = train_path
self.test_path = test_file
self.max_length = max_length
self.class_num = class_num
self.train_feature_dict = train_feature_dict
self.test_feature_dict = test_feature_dict
self.test_bucket = test_bucket
self.is_hashed = is_hashed
self.feature_size = feature_size
self.buckets = buckets
self.ngram = ngram
self.text_greater = '>'
self.text_less = '<'
self.ngram2idx = dict()
self.idx2gram = dict()
self.non_str = '\\'
self.end_string = ['.', '?', '!']
self.ngram2idx['PAD'] = 0
self.idx2gram[0] = 'PAD'
self.ngram2idx['UNK'] = 1
self.idx2gram[1] = 'UNK'
self.str_html = re.compile(r'<[^>]+>')
def load(self):
"""data preprocess loader"""
train_dataset_list = []
test_dataset_list = []
spacy_nlp = spacy.load('en_core_web_lg', disable=['parser', 'tagger', 'ner'])
spacy_nlp.add_pipe(spacy_nlp.create_pipe('sentencizer'))
with open(self.train_path, 'r', newline='', encoding='utf-8') as src_file:
reader = csv.reader(src_file, delimiter=",", quotechar='"')
for _, _pair_sen in enumerate(reader):
label_idx = int(_pair_sen[0]) - 1
if len(_pair_sen) == 3:
src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
src_text2=_pair_sen[2],
spacy_nlp=spacy_nlp,
train_mode=True)
src_tokens_length = len(src_tokens)
elif len(_pair_sen) == 2:
src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
src_text2=None,
spacy_nlp=spacy_nlp,
train_mode=True)
src_tokens_length = len(src_tokens)
elif len(_pair_sen) == 4:
if _pair_sen[2]:
sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2]
else:
sen_o_t = _pair_sen[1]
src_tokens = self.input_preprocess(src_text1=sen_o_t,
src_text2=_pair_sen[3],
spacy_nlp=spacy_nlp,
train_mode=True)
src_tokens_length = len(src_tokens)
train_dataset_list.append([src_tokens, src_tokens_length, label_idx])
print("Begin to process test data...")
with open(self.test_path, 'r', newline='', encoding='utf-8') as test_file:
reader2 = csv.reader(test_file, delimiter=",", quotechar='"')
for _, _test_sen in enumerate(reader2):
label_idx = int(_test_sen[0]) - 1
if len(_test_sen) == 3:
src_tokens = self.input_preprocess(src_text1=_test_sen[1],
src_text2=_test_sen[2],
spacy_nlp=spacy_nlp,
train_mode=False)
src_tokens_length = len(src_tokens)
elif len(_test_sen) == 2:
src_tokens = self.input_preprocess(src_text1=_test_sen[1],
src_text2=None,
spacy_nlp=spacy_nlp,
train_mode=False)
src_tokens_length = len(src_tokens)
elif len(_test_sen) == 4:
if _test_sen[2]:
sen_o_t = _test_sen[1] + ' ' + _test_sen[2]
else:
sen_o_t = _test_sen[1]
src_tokens = self.input_preprocess(src_text1=sen_o_t,
src_text2=_test_sen[3],
spacy_nlp=spacy_nlp,
train_mode=False)
src_tokens_length = len(src_tokens)
test_dataset_list.append([src_tokens, src_tokens_length, label_idx])
if self.is_hashed:
print("Begin to Hashing Trick......")
features_num = self.feature_size
fh = FeatureHasher(n_features=features_num, alternate_sign=False)
print("FeatureHasher features..", features_num)
self.hash_trick(fh, train_dataset_list)
self.hash_trick(fh, test_dataset_list)
print("Hashing Done....")
# pad train dataset
train_dataset_list_length = len(train_dataset_list)
test_dataset_list_length = len(test_dataset_list)
for l in range(train_dataset_list_length):
bucket_length = self._get_bucket_length(train_dataset_list[l][0], self.buckets)
while len(train_dataset_list[l][0]) < bucket_length:
train_dataset_list[l][0].append(self.ngram2idx['PAD'])
train_dataset_list[l][1] = len(train_dataset_list[l][0])
# pad test dataset
for j in range(test_dataset_list_length):
test_bucket_length = self._get_bucket_length(test_dataset_list[j][0], self.test_bucket)
while len(test_dataset_list[j][0]) < test_bucket_length:
test_dataset_list[j][0].append(self.ngram2idx['PAD'])
test_dataset_list[j][1] = len(test_dataset_list[j][0])
train_example_data = []
test_example_data = []
for idx in range(train_dataset_list_length):
train_example_data.append({
"src_tokens": train_dataset_list[idx][0],
"src_tokens_length": train_dataset_list[idx][1],
"label_idx": train_dataset_list[idx][2],
})
for key in self.train_feature_dict:
if key == train_example_data[idx]['src_tokens_length']:
self.train_feature_dict[key].append(train_example_data[idx])
for h in range(test_dataset_list_length):
test_example_data.append({
"src_tokens": test_dataset_list[h][0],
"src_tokens_length": test_dataset_list[h][1],
"label_idx": test_dataset_list[h][2],
})
for key in self.test_feature_dict:
if key == test_example_data[h]['src_tokens_length']:
self.test_feature_dict[key].append(test_example_data[h])
print("train vocab size is ", len(self.ngram2idx))
return self.train_feature_dict, self.test_feature_dict
def input_preprocess(self, src_text1, src_text2, spacy_nlp, train_mode):
"""data preprocess func"""
src_text1 = src_text1.strip()
if src_text1 and src_text1[-1] not in self.end_string:
src_text1 = src_text1 + '.'
if src_text2:
src_text2 = src_text2.strip()
sent_describe = src_text1 + ' ' + src_text2
else:
sent_describe = src_text1
if self.non_str in sent_describe:
sent_describe = sent_describe.replace(self.non_str, ' ')
sent_describe = html.unescape(sent_describe)
if self.text_less in sent_describe and self.text_greater in sent_describe:
sent_describe = self.str_html.sub('', sent_describe)
doc = spacy_nlp(sent_describe)
bows_token = [token.text for token in doc]
try:
tagged_sent_desc = '<p> ' + ' </s> '.join([s.text for s in doc.sents]) + ' </p>'
except ValueError:
tagged_sent_desc = '<p> ' + sent_describe + ' </p>'
doc = spacy_nlp(tagged_sent_desc)
ngrams = self.generate_gram([token.text for token in doc], num=self.ngram)
bo_ngrams = bows_token + ngrams
if train_mode is True:
for ngms in bo_ngrams:
idx = self.ngram2idx.get(ngms)
if idx is None:
idx = len(self.ngram2idx)
self.ngram2idx[ngms] = idx
self.idx2gram[idx] = ngms
processed_out = [self.ngram2idx[ng] if ng in self.ngram2idx else self.ngram2idx['UNK'] for ng in bo_ngrams]
return processed_out
def _get_bucket_length(self, x, bts):
x_len = len(x)
for index in range(1, len(bts)):
if bts[index-1] < x_len <= bts[index]:
return bts[index]
return bts[0]
def generate_gram(self, words, num=2):
return [' '.join(words[i: i + num]) for i in range(len(words) - num + 1)]
def count2dict(self, lst):
count_dict = dict()
for m in lst:
if str(m) in count_dict:
count_dict[str(m)] += 1
else:
count_dict[str(m)] = 1
return count_dict
def hash_trick(self, hashing, input_data):
trans = hashing.transform((self.count2dict(e[0]) for e in input_data))
for htr, e in zip(trans, input_data):
sparse2bow = list()
for idc, d in zip(htr.indices, htr.data):
for _ in range(int(d)):
sparse2bow.append(idc + 1)
e[0] = sparse2bow
def write_to_mindrecord(data, path, shared_num=1):
"""generate mindrecord"""
if not os.path.isabs(path):
path = os.path.abspath(path)
writer = FileWriter(path, shared_num)
data_schema = {
"src_tokens": {"type": "int32", "shape": [-1]},
"src_tokens_length": {"type": "int32", "shape": [-1]},
"label_idx": {"type": "int32", "shape": [-1]}
}
writer.add_schema(data_schema, "fasttext")
for item in data:
item['src_tokens'] = np.array(item['src_tokens'], dtype=np.int32)
item['src_tokens_length'] = np.array(item['src_tokens_length'], dtype=np.int32)
item['label_idx'] = np.array(item['label_idx'], dtype=np.int32)
writer.write_raw_data([item])
writer.commit()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--train_file', type=str, required=True, help='train dataset file path')
parser.add_argument('--test_file', type=str, required=True, help='test dataset file path')
    parser.add_argument('--class_num', type=int, required=True, help='Dataset class number')
    parser.add_argument('--ngram', type=int, default=2, required=False)
    parser.add_argument('--max_len', type=int, required=False, help='max sentence length in the dataset')
parser.add_argument('--bucket', type=ast.literal_eval, default=[64, 128, 467], help='bucket sequence length.')
parser.add_argument('--test_bucket', type=ast.literal_eval, default=[467], help='bucket sequence length.')
parser.add_argument('--is_hashed', type=bool, default=False, help='add hash trick for dataset')
parser.add_argument('--feature_size', type=int, default=10000000, help='hash feature size')
args = parser.parse_args()
pprint.PrettyPrinter().pprint(args.__dict__)
train_feature_dicts = {}
for i in args.bucket:
train_feature_dicts[i] = []
test_feature_dicts = {}
for i in args.test_bucket:
test_feature_dicts[i] = []
g_d = FastTextDataPreProcess(train_path=args.train_file,
test_file=args.test_file,
max_length=args.max_len,
ngram=args.ngram,
class_num=args.class_num,
train_feature_dict=train_feature_dicts,
buckets=args.bucket,
test_feature_dict=test_feature_dicts,
test_bucket=args.test_bucket,
is_hashed=args.is_hashed,
feature_size=args.feature_size)
train_data_example, test_data_example = g_d.load()
print("Data preprocess done")
print("Writing train data to MindRecord file......")
for i in args.bucket:
write_to_mindrecord(train_data_example[i], './train_dataset_bs_' + str(i) + '.mindrecord', 1)
print("Writing test data to MindRecord file.....")
for k in args.test_bucket:
write_to_mindrecord(test_data_example[k], './test_dataset_bs_' + str(k) + '.mindrecord', 1)
print("All done.....")


@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FastText model."""
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common.initializer import XavierUniform
from mindspore.common import dtype as mstype
class FastText(nn.Cell):
"""
FastText model
Args:
vocab_size: vocabulary size
embedding_dims: The size of each embedding vector
num_class: number of labels
"""
def __init__(self, vocab_size, embedding_dims, num_class):
super(FastText, self).__init__()
self.vocab_size = vocab_size
self.embeding_dims = embedding_dims
self.num_class = num_class
self.embeding_func = nn.Embedding(vocab_size=self.vocab_size,
embedding_size=self.embeding_dims,
padding_idx=0, embedding_table='Zeros')
self.fc = nn.Dense(self.embeding_dims, out_channels=self.num_class,
weight_init=XavierUniform(1)).to_float(mstype.float16)
self.reducesum = P.ReduceSum()
self.expand_dims = P.ExpandDims()
self.squeeze = P.Squeeze(axis=1)
self.cast = P.Cast()
self.tile = P.Tile()
self.realdiv = P.RealDiv()
self.fill = P.Fill()
self.log_softmax = nn.LogSoftmax(axis=1)
def construct(self, src_tokens, src_token_length):
"""
construct network
Args:
src_tokens: source sentences
src_token_length: source sentences length
Returns:
Tuple[Tensor], network outputs
"""
        src_tokens = self.embeding_func(src_tokens)
        # Hidden layer: sum the token embeddings and divide by the sentence length (average of word vectors).
        embeding = self.reducesum(src_tokens, 1)
        length_tiled = self.tile(src_token_length, (1, self.embeding_dims))
        embeding = self.realdiv(embeding, length_tiled)
        # Linear mapping from the averaged embedding to the label space (the dense layer runs in float16).
        embeding = self.cast(embeding, mstype.float16)
        classifier = self.fc(embeding)
        classifier = self.cast(classifier, mstype.float32)
        return classifier


@@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FastText for train"""
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.parameter import ParameterTuple
from mindspore.context import ParallelMode
from mindspore import nn
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore import context
from src.fasttext_model import FastText
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type not in (0, 1):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
class FastTextNetWithLoss(nn.Cell):
"""
Provide FastText training loss
Args:
vocab_size: vocabulary size
embedding_dims: The size of each embedding vector
num_class: number of labels
"""
def __init__(self, vocab_size, embedding_dims, num_class):
super(FastTextNetWithLoss, self).__init__()
self.fasttext = FastText(vocab_size, embedding_dims, num_class)
self.loss_func = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
self.squeeze = P.Squeeze(axis=1)
self.print = P.Print()
def construct(self, src_tokens, src_tokens_lengths, label_idx):
"""
FastText network with loss.
"""
predict_score = self.fasttext(src_tokens, src_tokens_lengths)
label_idx = self.squeeze(label_idx)
predict_score = self.loss_func(predict_score, label_idx)
return predict_score
class FastTextTrainOneStepCell(nn.Cell):
"""
Encapsulation class of fasttext network training.
    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
"""
def __init__(self, network, optimizer, sens=1.0):
super(FastTextTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode not in ParallelMode.MODE_LIST:
raise ValueError("Parallel mode does not support: ", self.parallel_mode)
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.hyper_map = C.HyperMap()
self.cast = P.Cast()
def set_sens(self, value):
self.sens = value
def construct(self,
src_token_text,
src_tokens_text_length,
label_idx_tag):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(src_token_text,
src_tokens_text_length,
label_idx_tag)
grads = self.grad(self.network, weights)(src_token_text,
src_tokens_text_length,
label_idx_tag,
self.cast(F.tuple_to_array((self.sens,)),
mstype.float32))
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
succ = self.optimizer(grads)
return F.depend(loss, succ)


@@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FastText data loader"""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC
def load_dataset(dataset_path,
batch_size,
epoch_count=1,
rank_size=1,
rank_id=0,
bucket=None,
shuffle=True):
"""dataset loader"""
def batch_per_bucket(bucket_length, input_file):
        if not input_file:
            raise FileNotFoundError("input file parameter must not be empty.")
        input_file = input_file + '/train_dataset_bs_' + str(bucket_length) + '.mindrecord'
ds = de.MindDataset(input_file,
columns_list=['src_tokens', 'src_tokens_length', 'label_idx'],
shuffle=shuffle,
num_shards=rank_size,
shard_id=rank_id,
num_parallel_workers=8)
ori_dataset_size = ds.get_dataset_size()
print(f"Dataset size: {ori_dataset_size}")
repeat_count = epoch_count
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
ds = ds.map(operations=type_cast_op, input_columns="label_idx")
ds = ds.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
ds = ds.batch(batch_size, drop_remainder=False)
ds = ds.repeat(repeat_count)
return ds
for i, _ in enumerate(bucket):
bucket_len = bucket[i]
ds_per = batch_per_bucket(bucket_len, dataset_path)
if i == 0:
ds = ds_per
else:
ds = ds + ds_per
ds = ds.shuffle(ds.get_dataset_size())
ds.channel_name = 'fasttext'
return ds


@@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate utilities."""
from math import ceil
import numpy as np
def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0):
"""
    Implementation of the polynomial decay learning rate scheduler, which cycles by default.
    Args:
        lr (float): Initial learning rate.
        warmup_steps (int): Warmup steps.
        decay_steps (int): Decay steps.
        total_update_num (int): Total update steps.
        min_lr (float): Minimum learning rate.
        power (float): Power factor.
Returns:
np.ndarray, learning rate of each step.
"""
lrs = np.zeros(shape=total_update_num, dtype=np.float32)
if decay_steps <= 0:
raise ValueError("`decay_steps` must larger than 1.")
_start_step = 0
if 0 < warmup_steps < total_update_num:
warmup_end_lr = lr
warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr
lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps)
_start_step = warmup_steps
for step in range(_start_step, total_update_num):
_step = step - _start_step
ratio = ceil(_step / decay_steps)
ratio = 1 if ratio < 1 else ratio
_decay_steps = decay_steps * ratio
lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr
return lrs
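# Illustrative usage (arbitrary values, not the shipped configuration): a 5-step linear
# warmup followed by cyclic polynomial decay over 20 total update steps.
if __name__ == '__main__':
    example_lrs = polynomial_decay_scheduler(lr=0.05, min_lr=1e-6, decay_steps=10,
                                             total_update_num=20, warmup_steps=5, power=0.5)
    print(example_lrs)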


@@ -0,0 +1,202 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FastText for train"""
import os
import time
import argparse
from mindspore import context
from mindspore.nn.optim import Adam
from mindspore.common import set_seed
from mindspore.train.model import Model
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback, TimeMonitor
from mindspore.communication import management as MultiAscend
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.load_dataset import load_dataset
from src.lr_schedule import polynomial_decay_scheduler
from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, required=True, help='FastText input data file path.')
parser.add_argument('--data_name', type=str, required=True, default='ag', help='dataset name. eg. ag, dbpedia')
args = parser.parse_args()
if args.data_name == "ag":
from src.config import config_ag as config
elif args.data_name == 'dbpedia':
from src.config import config_db as config
elif args.data_name == 'yelp_p':
    from src.config import config_yelpp as config
else:
    raise ValueError("Unsupported dataset name: {}".format(args.data_name))
def get_ms_timestamp():
t = time.time()
return int(round(t * 1000))
set_seed(5)
time_stamp_init = False
time_stamp_first = 0
rank_id = os.getenv('DEVICE_ID')
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend")
class LossCallBack(Callback):
"""
    Monitor the loss in training.
    If the loss is NAN or INF, terminate training.
    Note:
        If per_print_times is 0, the loss is not printed.
    Args:
        per_print_times (int): Print the loss every given number of steps. Default: 1.
"""
def __init__(self, per_print_times=1, rank_ids=0):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.rank_id = rank_ids
global time_stamp_init, time_stamp_first
if not time_stamp_init:
time_stamp_first = get_ms_timestamp()
time_stamp_init = True
def step_end(self, run_context):
"""Monitor the loss in training."""
global time_stamp_first
time_stamp_current = get_ms_timestamp()
cb_params = run_context.original_args()
print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first,
cb_params.cur_epoch_num,
cb_params.cur_step_num,
str(cb_params.net_outputs)))
with open("./loss_{}.log".format(self.rank_id), "a+") as f:
f.write("time: {}, epoch: {}, step: {}, loss: {}".format(
time_stamp_current - time_stamp_first,
cb_params.cur_epoch_num,
cb_params.cur_step_num,
str(cb_params.net_outputs.asnumpy())))
f.write('\n')
def _build_training_pipeline(pre_dataset):
"""
Build training pipeline
Args:
pre_dataset: preprocessed dataset
"""
net_with_loss = FastTextNetWithLoss(config.vocab_size, config.embedding_dims, config.num_class)
net_with_loss.init_parameters_data()
if config.pretrain_ckpt_dir:
parameter_dict = load_checkpoint(config.pretrain_ckpt_dir)
load_param_into_net(net_with_loss, parameter_dict)
if pre_dataset is None:
raise ValueError("pre-process dataset must be provided")
#get learning rate
update_steps = config.epoch * pre_dataset.get_dataset_size()
decay_steps = pre_dataset.get_dataset_size()
rank_size = os.getenv("RANK_SIZE")
if isinstance(rank_size, int):
raise ValueError("RANK_SIZE must be integer")
if rank_size is not None and int(rank_size) > 1:
base_lr = config.lr
else:
base_lr = config.lr / 10
print("+++++++++++Total update steps ", update_steps)
lr = Tensor(polynomial_decay_scheduler(lr=base_lr,
min_lr=config.min_lr,
decay_steps=decay_steps,
total_update_num=update_steps,
warmup_steps=config.warmup_steps,
power=config.poly_lr_scheduler_power), dtype=mstype.float32)
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.999)
net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer)
net_with_grads.set_train(True)
model = Model(net_with_grads)
loss_monitor = LossCallBack(rank_ids=rank_id)
dataset_size = pre_dataset.get_dataset_size()
time_monitor = TimeMonitor(data_size=dataset_size)
ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps,
keep_checkpoint_max=config.keep_ckpt_max)
callbacks = [time_monitor, loss_monitor]
if rank_size is None or int(rank_size) == 1:
ckpt_callback = ModelCheckpoint(prefix='fasttext',
directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))),
config=ckpt_config)
callbacks.append(ckpt_callback)
if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
ckpt_callback = ModelCheckpoint(prefix='fasttext',
directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))),
config=ckpt_config)
callbacks.append(ckpt_callback)
print("Prepare to Training....")
epoch_size = pre_dataset.get_repeat_count()
print("Epoch size ", epoch_size)
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
model.train(epoch=config.epoch, train_dataset=pre_dataset, callbacks=callbacks, dataset_sink_mode=False)
def train_single(input_file_path):
"""
Train model on single device
Args:
input_file_path: preprocessed dataset path
"""
print("Staring training on single device.")
preprocessed_data = load_dataset(dataset_path=input_file_path,
batch_size=config.batch_size,
epoch_count=config.epoch_count,
bucket=config.buckets)
_build_training_pipeline(preprocessed_data)
def set_parallel_env():
context.reset_auto_parallel_context()
MultiAscend.init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=MultiAscend.get_group_size(),
gradients_mean=True)
def train_parallel(input_file_path):
    """
    Train model on multiple devices
    Args:
        input_file_path: preprocessed dataset path
    """
    set_parallel_env()
    print("Starting training on multiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|")
preprocessed_data = load_dataset(dataset_path=input_file_path,
batch_size=config.batch_size,
epoch_count=config.epoch_count,
rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank(),
bucket=config.buckets,
shuffle=False)
_build_training_pipeline(preprocessed_data)
if __name__ == "__main__":
_rank_size = os.getenv("RANK_SIZE")
if _rank_size is not None and int(_rank_size) > 1:
        train_parallel(args.data_path)
else:
train_single(args.data_path)