diff --git a/model_zoo/README.md b/model_zoo/README.md
index a706ac6e954..87690b26e2c 100644
--- a/model_zoo/README.md
+++ b/model_zoo/README.md
@@ -46,6 +46,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
 - [BERT[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md)
 - [TinyBERT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/tinybert/README.md)
 - [GNMT V2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2/README.md)
+ - [FastText](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext/README.md)
 - [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
 - [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md)
 - [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
diff --git a/model_zoo/official/nlp/fasttext/README.md b/model_zoo/official/nlp/fasttext/README.md
new file mode 100644
index 00000000000..77a4095ed8c
--- /dev/null
+++ b/model_zoo/official/nlp/fasttext/README.md
@@ -0,0 +1,267 @@
+![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)
+
+<!-- TOC -->
+
+- [FastText](#fasttext)
+- [Model Structure](#model-structure)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+    - [Dataset Preparation](#dataset-preparation)
+    - [Configuration File](#configuration-file)
+    - [Training Process](#training-process)
+    - [Inference Process](#inference-process)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Training Performance](#training-performance)
+        - [Inference Performance](#inference-performance)
+- [Random Situation Description](#random-situation-description)
+- [Others](#others)
+- [ModelZoo HomePage](#modelzoo-homepage)
+
+<!-- /TOC -->
+
+# [FastText](#contents)
+
+FastText is a simple and efficient text classification algorithm proposed by Armand Joulin, Tomas Mikolov et al.
+in the 2016 article "Bag of Tricks for Efficient Text Classification". Its architecture is similar to CBOW,
+with the middle word replaced by a label. FastText adds n-gram features as additional features to capture some
+information about word order. It speeds up training and testing while maintaining high precision, and is widely
+used in various text classification tasks.
+
+[Paper](https://arxiv.org/pdf/1607.01759.pdf): "Bag of Tricks for Efficient Text Classification", 2016, A. Joulin, E. Grave, P. Bojanowski, and T. Mikolov
+
+# [Model Structure](#contents)
+
+The FastText model consists mainly of an input layer, a hidden layer and an output layer, where the input is a sequence of words (a text or sentence).
+The output layer gives the probability that the word sequence belongs to each category. The hidden layer is the average of the word embedding vectors:
+the input features are mapped to the hidden layer through a linear transformation, and the hidden layer is then mapped to the labels.
+
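+To make the averaging structure concrete, below is a minimal NumPy sketch of the forward pass described above.
+It is illustrative only; the actual implementation is `src/fasttext_model.py`, and every name here is hypothetical:
+
+```python
+import numpy as np
+
+def fasttext_forward(token_ids, embedding_table, fc_weight, fc_bias):
+    """Toy FastText forward pass: embed, average, classify."""
+    embedded = embedding_table[token_ids]    # (seq_len, embedding_dims)
+    hidden = embedded.mean(axis=0)           # hidden layer = average of embeddings
+    logits = fc_weight @ hidden + fc_bias    # linear map to class scores
+    return logits - np.log(np.exp(logits).sum())  # log-probabilities
+
+rng = np.random.default_rng(0)
+vocab_size, embedding_dims, num_class = 1000, 16, 4
+log_probs = fasttext_forward(token_ids=np.array([3, 17, 256]),
+                             embedding_table=rng.normal(size=(vocab_size, embedding_dims)),
+                             fc_weight=rng.normal(size=(num_class, embedding_dims)),
+                             fc_bias=np.zeros(num_class))
+print(log_probs.argmax())  # predicted class index
+```
+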
+# [Dataset](#contents)
+
+Note that you can run the scripts with the datasets mentioned in the original paper or with datasets widely used
+in this domain/network architecture. The following sections describe how to run the scripts using the datasets below.
+
+- AG's News Topic Classification Dataset
+- DBPedia Ontology Classification Dataset
+- Yelp Review Polarity Dataset
+
+# [Environment Requirements](#contents)
+
+- Hardware (Ascend)
+    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
+- Framework
+    - [MindSpore](https://gitee.com/mindspore/mindspore)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
+
+# [Quick Start](#contents)
+
+After dataset preparation, you can start training and evaluation as follows:
+
+```bash
+# run training example
+cd ./scripts
+sh run_standalone_train.sh [TRAIN_DATASET]
+
+# run distributed training example
+sh run_distribute_train_8p.sh [TRAIN_DATASET] [RANK_TABLE_PATH]
+
+# run evaluation example
+sh run_eval.sh [EVAL_DATASET_PATH] [DATASET_NAME] [MODEL_CKPT]
+```
+
+# [Script Description](#contents)
+
+The FastText network script and code structure are as follows:
+
+```text
+├── fasttext
+    ├── README.md                       // Introduction of the FastText model.
+    ├── src
+    │   ├── config.py                   // Configuration instance definitions.
+    │   ├── dataset.py                  // Dataset preparation.
+    │   ├── fasttext_model.py           // FastText model architecture.
+    │   ├── fasttext_train.py           // Training wrapper around the FastText model.
+    │   ├── load_dataset.py             // Dataset loader that feeds data into the model.
+    │   ├── lr_schedule.py              // Learning rate scheduler.
+    ├── scripts
+    │   ├── create_dataset.sh           // Shell script for dataset preparation.
+    │   ├── run_distribute_train_8p.sh  // Shell script for distributed training on Ascend.
+    │   ├── run_eval.sh                 // Shell script for standalone evaluation on Ascend.
+    │   ├── run_standalone_train.sh     // Shell script for standalone training on Ascend.
+    ├── eval.py                         // Inference API entry.
+    ├── requirements.txt                // Third-party package requirements.
+    ├── train.py                        // Train API entry.
+```
+
+## [Dataset Preparation](#contents)
+
+- Download the AG's News Topic Classification Dataset, DBPedia Ontology Classification Dataset and Yelp Review Polarity Dataset. Unzip the datasets to any path you want.
+
+- Run the following scripts to preprocess the data and convert the original data to MindRecord for training and evaluation.
+
+  ```bash
+  cd scripts
+  sh create_dataset.sh [SOURCE_DATASET_PATH] [DATASET_NAME]
+  ```
+
+## [Configuration File](#contents)
+
+Parameters for both training and evaluation can be set in config.py. All datasets use the same parameter names; the values can be adjusted as needed.
+
+- Network Parameters
+
+  ```text
+  vocab_size                # vocabulary size.
+  buckets                   # bucket sequence lengths.
+  batch_size                # batch size of the input dataset.
+  embedding_dims            # size of each embedding vector.
+  num_class                 # number of labels.
+  epoch                     # total training epochs.
+  lr                        # initial learning rate.
+  min_lr                    # minimum learning rate.
+  warmup_steps              # warmup steps.
+  poly_lr_scheduler_power   # power used to compute the decayed learning rate.
+  pretrain_ckpt_dir         # pretrained checkpoint directory.
+  keep_ckpt_max             # maximum number of checkpoint files to keep.
+  ```
+
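+The `buckets` parameter drives a simple length-bucketing scheme: during preprocessing, every sample is padded
+up to the smallest bucket length that fits it, so all samples in a bucket share one sequence length. A simplified,
+illustrative sketch of that logic (the real version is `_get_bucket_length` in `src/dataset.py`):
+
+```python
+def bucket_length(seq_len, buckets):
+    """Return the smallest bucket that can hold a sequence of seq_len tokens."""
+    for b in buckets:
+        if seq_len <= b:
+            return b
+    return buckets[-1]  # sequences longer than the last bucket are capped there
+
+def pad_to_bucket(token_ids, buckets, pad_id=0):
+    """Pad a token-id list with PAD ids up to its bucket length."""
+    target = bucket_length(len(token_ids), buckets)
+    return token_ids + [pad_id] * (target - len(token_ids))
+
+print(len(pad_to_bucket([5, 9, 2], buckets=[64, 128, 467])))  # -> 64
+```
+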
+## [Training Process](#contents)
+
+- Start training on a single device by running the shell script:
+
+  ```bash
+  cd ./scripts
+  sh run_standalone_train.sh [DATASET_PATH]
+  ```
+
+- For distributed training of FastText, run the following commands in bash from `scripts/`:
+
+  ```bash
+  cd ./scripts
+  sh run_distribute_train_8p.sh [DATASET_PATH] [RANK_TABLE_PATH]
+  ```
+
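+Training uses a linear warmup followed by polynomial decay of the learning rate (implemented in
+`src/lr_schedule.py` and configured by `lr`, `min_lr`, `warmup_steps` and `poly_lr_scheduler_power`).
+The sketch below is a simplified, non-cycling variant with made-up values, just to show the shape of the
+schedule; the repository version also cycles the decay when training runs past `decay_steps`:
+
+```python
+import numpy as np
+
+def poly_decay_lr(lr, min_lr, decay_steps, total_steps, warmup_steps, power=0.5):
+    """Linear warmup, then polynomial decay toward min_lr (simplified, no cycling)."""
+    lrs = np.zeros(total_steps, dtype=np.float32)
+    warmup = min(warmup_steps, total_steps)
+    lrs[:warmup] = np.linspace(0.0, lr, warmup)  # linear warmup from 0 to lr
+    for step in range(warmup, total_steps):
+        frac = min((step - warmup) / decay_steps, 1.0)
+        lrs[step] = (lr - min_lr) * (1.0 - frac) ** power + min_lr
+    return lrs
+
+schedule = poly_decay_lr(lr=0.05, min_lr=1e-6, decay_steps=115,
+                         total_steps=580, warmup_steps=100)
+print(schedule[0], schedule[99], schedule[-1])  # 0.0, 0.05, ~1e-6
+```
+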
+## [Inference Process](#contents)
+
+- Run the following script to evaluate FastText. The command is as follows:
+
+  ```bash
+  cd ./scripts
+  sh run_eval.sh [DATASET_PATH] [DATASET_NAME] [MODEL_CKPT]
+  ```
+
+  Note: `DATASET_PATH` is the path to the MindRecord files, e.g. /dataset_path/*.mindrecord.
+
+# [Model Description](#contents)
+
+## [Performance](#contents)
+
+### Training Performance
+
+| Parameters               | Ascend                                                          |
+| ------------------------ | --------------------------------------------------------------- |
+| Resource                 | Ascend 910                                                      |
+| Uploaded Date            | 12/21/2020 (month/day/year)                                     |
+| MindSpore Version        | 1.1.0                                                           |
+| Dataset                  | AG's News Topic Classification Dataset                          |
+| Training Parameters      | epoch=5, batch_size=128                                         |
+| Optimizer                | Adam                                                            |
+| Loss Function            | Softmax Cross Entropy                                           |
+| outputs                  | probability                                                     |
+| Speed                    | 112ms/step (8pcs)                                               |
+| Total Time               | 66s (8pcs)                                                      |
+| Loss                     | 0.00082                                                         |
+| Params (M)               | 22                                                              |
+| Checkpoint for inference | 254M (.ckpt file)                                               |
+| Scripts                  | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
+
+| Parameters               | Ascend                                                          |
+| ------------------------ | --------------------------------------------------------------- |
+| Resource                 | Ascend 910                                                      |
+| Uploaded Date            | 11/21/2020 (month/day/year)                                     |
+| MindSpore Version        | 1.1.0                                                           |
+| Dataset                  | DBPedia Ontology Classification Dataset                         |
+| Training Parameters      | epoch=5, batch_size=128                                         |
+| Optimizer                | Adam                                                            |
+| Loss Function            | Softmax Cross Entropy                                           |
+| outputs                  | probability                                                     |
+| Speed                    | 60ms/step (8pcs)                                                |
+| Total Time               | 164s (8pcs)                                                     |
+| Loss                     | 2.6e-5                                                          |
+| Params (M)               | 106                                                             |
+| Checkpoint for inference | 1.2G (.ckpt file)                                               |
+| Scripts                  | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
+
+| Parameters               | Ascend                                                          |
+| ------------------------ | --------------------------------------------------------------- |
+| Resource                 | Ascend 910                                                      |
+| Uploaded Date            | 11/21/2020 (month/day/year)                                     |
+| MindSpore Version        | 1.1.0                                                           |
+| Dataset                  | Yelp Review Polarity Dataset                                    |
+| Training Parameters      | epoch=5, batch_size=128                                         |
+| Optimizer                | Adam                                                            |
+| Loss Function            | Softmax Cross Entropy                                           |
+| outputs                  | probability                                                     |
+| Speed                    | 74ms/step (8pcs)                                                |
+| Total Time               | 195s (8pcs)                                                     |
+| Loss                     | 7.7e-4                                                          |
+| Params (M)               | 103                                                             |
+| Checkpoint for inference | 1.2G (.ckpt file)                                               |
+| Scripts                  | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
+
+### Inference Performance
+
+| Parameters          | Ascend                                  |
+| ------------------- | --------------------------------------- |
+| Resource            | Ascend 910                              |
+| Uploaded Date       | 12/21/2020 (month/day/year)             |
+| MindSpore Version   | 1.1.0                                   |
+| Dataset             | AG's News Topic Classification Dataset  |
+| batch_size          | 128                                     |
+| Total Time          | 66s                                     |
+| outputs             | label index                             |
+| Accuracy            | 92.53%                                  |
+| Model for inference | 254M (.ckpt file)                       |
+
+| Parameters          | Ascend                                  |
+| ------------------- | --------------------------------------- |
+| Resource            | Ascend 910                              |
+| Uploaded Date       | 12/21/2020 (month/day/year)             |
+| MindSpore Version   | 1.1.0                                   |
+| Dataset             | DBPedia Ontology Classification Dataset |
+| batch_size          | 128                                     |
+| Total Time          | 164s                                    |
+| outputs             | label index                             |
+| Accuracy            | 98.6%                                   |
+| Model for inference | 1.2G (.ckpt file)                       |
+
+| Parameters          | Ascend                                  |
+| ------------------- | --------------------------------------- |
+| Resource            | Ascend 910                              |
+| Uploaded Date       | 12/21/2020 (month/day/year)             |
+| MindSpore Version   | 1.1.0                                   |
+| Dataset             | Yelp Review Polarity Dataset            |
+| batch_size          | 128                                     |
+| Total Time          | 195s                                    |
+| outputs             | label index                             |
+| Accuracy            | 95.7%                                   |
+| Model for inference | 1.2G (.ckpt file)                       |
+
+# [Random Situation Description](#contents)
+
+There is only one source of randomness:
+
+- Initialization of some model weights.
+
+Seeds are already set in train.py to avoid randomness in weight initialization.
+
+# [Others](#others)
+
+This model has been validated in the Ascend environment; it has not been validated on CPU or GPU.
+
+# [ModelZoo HomePage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/model_zoo/official/nlp/fasttext/eval.py b/model_zoo/official/nlp/fasttext/eval.py
new file mode 100644
index 00000000000..c41963fb810
--- /dev/null
+++ b/model_zoo/official/nlp/fasttext/eval.py
@@ -0,0 +1,118 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FastText for Evaluation"""
+import argparse
+import numpy as np
+from sklearn.metrics import accuracy_score, classification_report
+
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+import mindspore.ops.operations as P
+from mindspore.common.tensor import Tensor
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+import mindspore.dataset.engine as de
+import mindspore.dataset.transforms.c_transforms as deC
+from mindspore import context
+from src.fasttext_model import FastText
+
+parser = argparse.ArgumentParser(description='fasttext')
+parser.add_argument('--data_path', type=str, help='infer dataset path.')
+parser.add_argument('--data_name', type=str, required=True, default='ag',
+                    help='dataset name, e.g. ag, dbpedia')
+parser.add_argument("--model_ckpt", type=str, required=True,
+                    help="path to an existing checkpoint file.")
+args = parser.parse_args()
+if args.data_name == "ag":
+    from src.config import config_ag as config
+    target_label1 = ['0', '1', '2', '3']
+elif args.data_name == 'dbpedia':
+    from src.config import config_db as config
+    target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
+elif args.data_name == 'yelp_p':
+    from src.config import config_yelpp as config
+    target_label1 = ['0', '1']
+context.set_context(
+    mode=context.GRAPH_MODE,
+    save_graphs=False,
+    device_target="Ascend")
+
+class FastTextInferCell(nn.Cell):
+    """
+    Encapsulation class of FastText network inference.
+
+    Args:
+        network (nn.Cell): FastText model.
+
+    Returns:
+        Tensor, predicted label index.
+    """
+    def __init__(self, network):
+        super(FastTextInferCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
+        self.log_softmax = nn.LogSoftmax(axis=1)
+
+    def construct(self, src_tokens, src_tokens_lengths):
+        """construct fasttext infer cell"""
+        prediction = self.network(src_tokens, src_tokens_lengths)
+        predicted_idx = self.log_softmax(prediction)
+        predicted_idx, _ = self.argmax(predicted_idx)
+
+        return predicted_idx
+
+def load_infer_dataset(batch_size, datafile):
+    """data loader for inference"""
+    ds = de.MindDataset(datafile, columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
+
+    type_cast_op = deC.TypeCast(mstype.int32)
+    ds = ds.map(operations=type_cast_op, input_columns="src_tokens")
+    ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length")
+    ds = ds.map(operations=type_cast_op, input_columns="label_idx")
+    ds = ds.batch(batch_size=batch_size, drop_remainder=True)
+
+    return ds
+
+def run_fasttext_infer():
+    """run inference with FastText"""
+    dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path)
+    fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
+
+    parameter_dict = load_checkpoint(args.model_ckpt)
+    load_param_into_net(fasttext_model, parameter_dict=parameter_dict)
+
+    ft_infer = FastTextInferCell(fasttext_model)
+
+    model = Model(ft_infer)
+
+    predictions = []
+    target_sens = []
+
+    for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
+        target_sens.append(batch['label_idx'])
+        src_tokens = Tensor(batch['src_tokens'], mstype.int32)
+        src_tokens_length = Tensor(batch['src_tokens_length'], mstype.int32)
+        predicted_idx = model.predict(src_tokens, src_tokens_length)
+        predictions.append(predicted_idx.asnumpy())
+
+    target_sens = np.array(target_sens).flatten()
+    predictions = np.array(predictions).flatten()
+    acc = accuracy_score(target_sens, predictions)
+
+    result_report = classification_report(target_sens, predictions, target_names=target_label1)
+    print("********Accuracy: ", acc)
+    print(result_report)
+
+if __name__ == '__main__':
+    run_fasttext_infer()
diff --git a/model_zoo/official/nlp/fasttext/requirements.txt b/model_zoo/official/nlp/fasttext/requirements.txt
new file mode 100644
index 00000000000..6ffd5e16969
--- /dev/null
+++ b/model_zoo/official/nlp/fasttext/requirements.txt
@@ -0,0 +1,3 @@
+spacy
+sklearn
+en_core_web_lg
diff --git a/model_zoo/official/nlp/fasttext/scripts/create_dataset.sh b/model_zoo/official/nlp/fasttext/scripts/create_dataset.sh
new file mode 100644
index 00000000000..b4dd4172b61
--- /dev/null
+++ b/model_zoo/official/nlp/fasttext/scripts/create_dataset.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "sh create_dataset.sh SOURCE_DATASET_PATH DATASET_NAME" +echo "for example: sh create_dataset.sh /home/workspace/ag_news_csv ag" +echo "DATASET_NAME including ag, dbpedia, and yelp_p" +echo "It is better to use absolute path." +echo "==============================================================================================================" +ulimit -u unlimited +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +SOURCE_DATASET_PATH=$(get_real_path $1) +DATASET_NAME=$2 + +export DEVICE_NUM=1 +export DEVICE_ID=5 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ $DATASET_NAME == 'ag' ]; +then + echo "Begin to process ag news data" + if [ -d "ag" ]; + then + rm -rf ./ag + fi + mkdir ./ag + cd ./ag || exit + echo "start data preprocess for device $DEVICE_ID" + python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 4 --max_len 467 --bucket [64,128,467] --test_bucket [467] + cd .. +fi + +if [ $DATASET_NAME == 'dbpedia' ]; +then + echo "Begin to process dbpedia data" + if [ -d "dbpedia" ]; + then + rm -rf ./dbpedia + fi + mkdir ./dbpedia + cd ./dbpedia || exit + echo "start data preprocess for device $DEVICE_ID" + python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 14 --max_len 3013 --bucket [128,512,3013] --test_bucket [1120] + cd .. +fi + +if [ $DATASET_NAME == 'yelp_p' ]; +then + echo "Begin to process ag news data" + if [ -d "yelp_p" ]; + then + rm -rf ./yelp_p + fi + mkdir ./yelp_p + cd ./yelp_p || exit + echo "start data preprocess for device $DEVICE_ID" + python ../../src/dataset.py --train_file $SOURCE_DATASET_PATH/train.csv --test_file $SOURCE_DATASET_PATH/test.csv --class_num 2 --max_len 2955 --bucket [64,128,256,512,2955] --test_bucket [2955] + cd .. +fi + + + + diff --git a/model_zoo/official/nlp/fasttext/scripts/run_distribute_train_8p.sh b/model_zoo/official/nlp/fasttext/scripts/run_distribute_train_8p.sh new file mode 100644 index 00000000000..ab5469d4380 --- /dev/null +++ b/model_zoo/official/nlp/fasttext/scripts/run_distribute_train_8p.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "sh run_distributed_train.sh DATASET_PATH RANK_TABLE_PATH" +echo "for example: sh run_distributed_train.sh /home/workspace/ag /home/workspace/rank_table_file.json" +echo "It is better to use absolute path." 
+echo "==============================================================================================================" +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET=$(get_real_path $1) +echo $DATASET +DATANAME=$(basename $DATASET) +RANK_TABLE_PATH=$(get_real_path $2) +echo $DATANAME +if [ ! -d $DATASET ] +then + echo "Error: DATA_PATH=$DATASET is not a file" +exit 1 +fi +current_exec_path=$(pwd) +echo ${current_exec_path} + +export RANK_TABLE_FILE=$RANK_TABLE_PATH + + +echo $RANK_TABLE_FILE +export RANK_SIZE=8 +export DEVICE_NUM=8 + + +for((i=0;i<=7;i++)); +do + rm -rf ${current_exec_path}/device$i + mkdir ${current_exec_path}/device$i + cd ${current_exec_path}/device$i + cp ../../*.py ./ + cp -r ../../src ./ + cp -r ../*.sh ./ + export RANK_ID=$i + export DEVICE_ID=$i + echo "start training for rank $i, device $DEVICE_ID" + python ../../train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 & + cd ${current_exec_path} +done +cd ${current_exec_path} diff --git a/model_zoo/official/nlp/fasttext/scripts/run_eval.sh b/model_zoo/official/nlp/fasttext/scripts/run_eval.sh new file mode 100644 index 00000000000..beaef1998a3 --- /dev/null +++ b/model_zoo/official/nlp/fasttext/scripts/run_eval.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "sh run_eval.sh DATASET_PATH DATASET_NAME MODEL_CKPT" +echo "for example: sh run_eval.sh /home/workspace/ag/test*.mindrecord ag device0/ckpt0/fasttext-5-118.ckpt" +echo "It is better to use absolute path." +echo "==============================================================================================================" + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET=$(get_real_path $1) +echo $DATASET +DATANAME=$2 +MODEL_CKPT=$(get_real_path $3) +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=5 +export RANK_ID=0 +export RANK_SIZE=1 + + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp -r ../src ./eval +cp -r ../scripts/*.sh ./eval +cd ./eval || exit +echo "start training for device $DEVICE_ID" +env > env.log +python ../../eval.py --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 & +cd .. 
diff --git a/model_zoo/official/nlp/fasttext/scripts/run_standalone_train.sh b/model_zoo/official/nlp/fasttext/scripts/run_standalone_train.sh
new file mode 100644
index 00000000000..d85aa1eaf01
--- /dev/null
+++ b/model_zoo/official/nlp/fasttext/scripts/run_standalone_train.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "sh run_standalone_train.sh DATASET_PATH"
+echo "for example: sh run_standalone_train.sh /home/workspace/ag"
+echo "It is better to use absolute path."
+echo "=============================================================================================================="
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+DATASET=$(get_real_path $1)
+echo $DATASET
+DATANAME=$(basename $DATASET)
+echo $DATANAME
+
+ulimit -u unlimited
+export DEVICE_NUM=1
+export DEVICE_ID=0
+export RANK_ID=0
+export RANK_SIZE=1
+
+if [ -d "train" ];
+then
+  rm -rf ./train
+fi
+mkdir ./train
+cp ../*.py ./train
+cp -r ../src ./train
+cp -r ../scripts/*.sh ./train
+cd ./train || exit
+echo "start training for device $DEVICE_ID"
+env > env.log
+python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
+cd ..
diff --git a/model_zoo/official/nlp/fasttext/src/config.py b/model_zoo/official/nlp/fasttext/src/config.py
new file mode 100644
index 00000000000..38e3e9b55d7
--- /dev/null
+++ b/model_zoo/official/nlp/fasttext/src/config.py
@@ -0,0 +1,72 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#" :=========================================================================== +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as ed + +config_yelpp = ed({ + 'vocab_size': 6414979, + 'buckets': [64, 128, 256, 512, 2955], + 'batch_size': 128, + 'embedding_dims': 16, + 'num_class': 2, + 'epoch': 5, + 'lr': 0.02, + 'min_lr': 1e-6, + 'decay_steps': 549, + 'warmup_steps': 400000, + 'poly_lr_scheduler_power': 0.5, + 'epoch_count': 1, + 'pretrain_ckpt_dir': None, + 'save_ckpt_steps': 549, + 'keep_ckpt_max': 10, +}) + +config_db = ed({ + 'vocab_size': 6596536, + 'buckets': [128, 512, 3013], + 'batch_size': 128, + 'embedding_dims': 16, + 'num_class': 14, + 'epoch': 5, + 'lr': 0.05, + 'min_lr': 1e-6, + 'decay_steps': 549, + 'warmup_steps': 400000, + 'poly_lr_scheduler_power': 0.5, + 'epoch_count': 1, + 'pretrain_ckpt_dir': None, + 'save_ckpt_steps': 548, + 'keep_ckpt_max': 10, +}) + +config_ag = ed({ + 'vocab_size': 1383812, + 'buckets': [64, 128, 467], + 'batch_size': 128, + 'embedding_dims': 16, + 'num_class': 4, + 'epoch': 5, + 'lr': 0.05, + 'min_lr': 1e-6, + 'decay_steps': 115, + 'warmup_steps': 400000, + 'poly_lr_scheduler_power': 0.5, + 'epoch_count': 1, + 'pretrain_ckpt_dir': None, + 'save_ckpt_steps': 116, + 'keep_ckpt_max': 10, +}) diff --git a/model_zoo/official/nlp/fasttext/src/dataset.py b/model_zoo/official/nlp/fasttext/src/dataset.py new file mode 100644 index 00000000000..359ad665bb6 --- /dev/null +++ b/model_zoo/official/nlp/fasttext/src/dataset.py @@ -0,0 +1,316 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""FastText data preprocess""" +import csv +import os +import re +import argparse +import pprint +import ast +import html +import numpy as np +import spacy +from sklearn.feature_extraction import FeatureHasher +from mindspore.mindrecord import FileWriter + +class FastTextDataPreProcess(): + """FastText data preprocess""" + def __init__(self, train_path, + test_file, + max_length, + class_num, + ngram, + train_feature_dict, + buckets, + test_feature_dict, + test_bucket, + is_hashed, + feature_size): + self.train_path = train_path + self.test_path = test_file + self.max_length = max_length + self.class_num = class_num + self.train_feature_dict = train_feature_dict + self.test_feature_dict = test_feature_dict + self.test_bucket = test_bucket + self.is_hashed = is_hashed + self.feature_size = feature_size + self.buckets = buckets + self.ngram = ngram + self.text_greater = '>' + self.text_less = '<' + self.ngram2idx = dict() + self.idx2gram = dict() + self.non_str = '\\' + self.end_string = ['.', '?', '!'] + self.ngram2idx['PAD'] = 0 + self.idx2gram[0] = 'PAD' + self.ngram2idx['UNK'] = 1 + self.idx2gram[1] = 'UNK' + self.str_html = re.compile(r'<[^>]+>') + + def load(self): + """data preprocess loader""" + train_dataset_list = [] + test_dataset_list = [] + spacy_nlp = spacy.load('en_core_web_lg', disable=['parser', 'tagger', 'ner']) + spacy_nlp.add_pipe(spacy_nlp.create_pipe('sentencizer')) + + with open(self.train_path, 'r', newline='', encoding='utf-8') as src_file: + reader = csv.reader(src_file, delimiter=",", quotechar='"') + for _, _pair_sen in enumerate(reader): + label_idx = int(_pair_sen[0]) - 1 + if len(_pair_sen) == 3: + src_tokens = self.input_preprocess(src_text1=_pair_sen[1], + src_text2=_pair_sen[2], + spacy_nlp=spacy_nlp, + train_mode=True) + src_tokens_length = len(src_tokens) + elif len(_pair_sen) == 2: + src_tokens = self.input_preprocess(src_text1=_pair_sen[1], + src_text2=None, + spacy_nlp=spacy_nlp, + train_mode=True) + src_tokens_length = len(src_tokens) + elif len(_pair_sen) == 4: + if _pair_sen[2]: + sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2] + else: + sen_o_t = _pair_sen[1] + src_tokens = self.input_preprocess(src_text1=sen_o_t, + src_text2=_pair_sen[3], + spacy_nlp=spacy_nlp, + train_mode=True) + src_tokens_length = len(src_tokens) + + train_dataset_list.append([src_tokens, src_tokens_length, label_idx]) + + print("Begin to process test data...") + with open(self.test_path, 'r', newline='', encoding='utf-8') as test_file: + reader2 = csv.reader(test_file, delimiter=",", quotechar='"') + for _, _test_sen in enumerate(reader2): + label_idx = int(_test_sen[0]) - 1 + if len(_test_sen) == 3: + src_tokens = self.input_preprocess(src_text1=_test_sen[1], + src_text2=_test_sen[2], + spacy_nlp=spacy_nlp, + train_mode=False) + src_tokens_length = len(src_tokens) + elif len(_test_sen) == 2: + src_tokens = self.input_preprocess(src_text1=_test_sen[1], + src_text2=None, + spacy_nlp=spacy_nlp, + train_mode=False) + src_tokens_length = len(src_tokens) + elif len(_test_sen) == 4: + if _test_sen[2]: + sen_o_t = _test_sen[1] + ' ' + _test_sen[2] + else: + sen_o_t = _test_sen[1] + src_tokens = self.input_preprocess(src_text1=sen_o_t, + src_text2=_test_sen[3], + spacy_nlp=spacy_nlp, + train_mode=False) + src_tokens_length = len(src_tokens) + + test_dataset_list.append([src_tokens, src_tokens_length, label_idx]) + + if self.is_hashed: + print("Begin to Hashing Trick......") + features_num = self.feature_size 
+            fh = FeatureHasher(n_features=features_num, alternate_sign=False)
+            print("FeatureHasher features..", features_num)
+            self.hash_trick(fh, train_dataset_list)
+            self.hash_trick(fh, test_dataset_list)
+            print("Hashing Done....")
+
+        # pad train dataset
+        train_dataset_list_length = len(train_dataset_list)
+        test_dataset_list_length = len(test_dataset_list)
+        for l in range(train_dataset_list_length):
+            bucket_length = self._get_bucket_length(train_dataset_list[l][0], self.buckets)
+            while len(train_dataset_list[l][0]) < bucket_length:
+                train_dataset_list[l][0].append(self.ngram2idx['PAD'])
+            train_dataset_list[l][1] = len(train_dataset_list[l][0])
+        # pad test dataset
+        for j in range(test_dataset_list_length):
+            test_bucket_length = self._get_bucket_length(test_dataset_list[j][0], self.test_bucket)
+            while len(test_dataset_list[j][0]) < test_bucket_length:
+                test_dataset_list[j][0].append(self.ngram2idx['PAD'])
+            test_dataset_list[j][1] = len(test_dataset_list[j][0])
+
+        train_example_data = []
+        test_example_data = []
+        for idx in range(train_dataset_list_length):
+            train_example_data.append({
+                "src_tokens": train_dataset_list[idx][0],
+                "src_tokens_length": train_dataset_list[idx][1],
+                "label_idx": train_dataset_list[idx][2],
+            })
+            for key in self.train_feature_dict:
+                if key == train_example_data[idx]['src_tokens_length']:
+                    self.train_feature_dict[key].append(train_example_data[idx])
+        for h in range(test_dataset_list_length):
+            test_example_data.append({
+                "src_tokens": test_dataset_list[h][0],
+                "src_tokens_length": test_dataset_list[h][1],
+                "label_idx": test_dataset_list[h][2],
+            })
+            for key in self.test_feature_dict:
+                if key == test_example_data[h]['src_tokens_length']:
+                    self.test_feature_dict[key].append(test_example_data[h])
+        print("train vocab size is ", len(self.ngram2idx))
+
+        return self.train_feature_dict, self.test_feature_dict
+
+    def input_preprocess(self, src_text1, src_text2, spacy_nlp, train_mode):
+        """data preprocess func"""
+        src_text1 = src_text1.strip()
+        if src_text1 and src_text1[-1] not in self.end_string:
+            src_text1 = src_text1 + '.'
+
+        if src_text2:
+            src_text2 = src_text2.strip()
+            sent_describe = src_text1 + ' ' + src_text2
+        else:
+            sent_describe = src_text1
+        if self.non_str in sent_describe:
+            sent_describe = sent_describe.replace(self.non_str, ' ')
+
+        sent_describe = html.unescape(sent_describe)
+
+        if self.text_less in sent_describe and self.text_greater in sent_describe:
+            sent_describe = self.str_html.sub('', sent_describe)
+
+        doc = spacy_nlp(sent_describe)
+        bows_token = [token.text for token in doc]
+
+        try:
+            # Wrap the sentence-split text in boundary tags before n-gram extraction.
+            tagged_sent_desc = '<p> ' + ' '.join([s.text for s in doc.sents]) + ' </p>'
+        except ValueError:
+            tagged_sent_desc = '<p> ' + sent_describe + ' </p>'
+        doc = spacy_nlp(tagged_sent_desc)
+        ngrams = self.generate_gram([token.text for token in doc], num=self.ngram)
+
+        bo_ngrams = bows_token + ngrams
+
+        if train_mode is True:
+            for ngms in bo_ngrams:
+                idx = self.ngram2idx.get(ngms)
+                if idx is None:
+                    idx = len(self.ngram2idx)
+                    self.ngram2idx[ngms] = idx
+                    self.idx2gram[idx] = ngms
+
+        processed_out = [self.ngram2idx[ng] if ng in self.ngram2idx else self.ngram2idx['UNK'] for ng in bo_ngrams]
+
+        return processed_out
+
+    def _get_bucket_length(self, x, bts):
+        x_len = len(x)
+        for index in range(1, len(bts)):
+            if bts[index-1] < x_len <= bts[index]:
+                return bts[index]
+        return bts[0]
+
+    def generate_gram(self, words, num=2):
+        return [' '.join(words[i: i + num]) for i in range(len(words) - num + 1)]
+
+    def count2dict(self, lst):
+        count_dict = dict()
+        for m in lst:
+            if str(m) in count_dict:
+                count_dict[str(m)] += 1
+            else:
+                count_dict[str(m)] = 1
+        return count_dict
+
+    def hash_trick(self, hashing, input_data):
+        trans = hashing.transform((self.count2dict(e[0]) for e in input_data))
+        for htr, e in zip(trans, input_data):
+            sparse2bow = list()
+            for idc, d in zip(htr.indices, htr.data):
+                for _ in range(int(d)):
+                    sparse2bow.append(idc + 1)
+            e[0] = sparse2bow
+
+
+def write_to_mindrecord(data, path, shared_num=1):
+    """generate mindrecord"""
+    if not os.path.isabs(path):
+        path = os.path.abspath(path)
+
+    writer = FileWriter(path, shared_num)
+    data_schema = {
+        "src_tokens": {"type": "int32", "shape": [-1]},
+        "src_tokens_length": {"type": "int32", "shape": [-1]},
+        "label_idx": {"type": "int32", "shape": [-1]}
+    }
+    writer.add_schema(data_schema, "fasttext")
+    for item in data:
+        item['src_tokens'] = np.array(item['src_tokens'], dtype=np.int32)
+        item['src_tokens_length'] = np.array(item['src_tokens_length'], dtype=np.int32)
+        item['label_idx'] = np.array(item['label_idx'], dtype=np.int32)
+        writer.write_raw_data([item])
+    writer.commit()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--train_file', type=str, required=True, help='train dataset file path')
+    parser.add_argument('--test_file', type=str, required=True, help='test dataset file path')
+    parser.add_argument('--class_num', type=int, required=True, help='number of dataset classes')
+    parser.add_argument('--ngram', type=int, default=2, required=False)
+    parser.add_argument('--max_len', type=int, required=False, help='max sentence length in the dataset')
+    parser.add_argument('--bucket', type=ast.literal_eval, default=[64, 128, 467], help='bucket sequence length.')
+    parser.add_argument('--test_bucket', type=ast.literal_eval, default=[467], help='bucket sequence length.')
+    parser.add_argument('--is_hashed', type=bool, default=False, help='add hash trick for dataset')
+    parser.add_argument('--feature_size', type=int, default=10000000, help='hash feature size')
+
+    args = parser.parse_args()
+    pprint.PrettyPrinter().pprint(args.__dict__)
+    train_feature_dicts = {}
+    for i in args.bucket:
+        train_feature_dicts[i] = []
+    test_feature_dicts = {}
+    for i in args.test_bucket:
+        test_feature_dicts[i] = []
+
+    g_d = FastTextDataPreProcess(train_path=args.train_file,
+                                 test_file=args.test_file,
+                                 max_length=args.max_len,
+                                 ngram=args.ngram,
+                                 class_num=args.class_num,
+                                 train_feature_dict=train_feature_dicts,
+                                 buckets=args.bucket,
+                                 test_feature_dict=test_feature_dicts,
+                                 test_bucket=args.test_bucket,
+                                 is_hashed=args.is_hashed,
+                                 feature_size=args.feature_size)
+    train_data_example, test_data_example = g_d.load()
+    print("Data preprocess done")
+    print("Writing train data to MindRecord file......")
+
+    for i in args.bucket:
+        write_to_mindrecord(train_data_example[i], './train_dataset_bs_' + str(i) + '.mindrecord', 1)
+
+    print("Writing test data to MindRecord file.....")
+    for k in args.test_bucket:
+        write_to_mindrecord(test_data_example[k], './test_dataset_bs_' + str(k) + '.mindrecord', 1)
+
+    print("All done.....")
diff --git a/model_zoo/official/nlp/fasttext/src/fasttext_model.py b/model_zoo/official/nlp/fasttext/src/fasttext_model.py
new file mode 100644
index 00000000000..72d49840503
--- /dev/null
+++ b/model_zoo/official/nlp/fasttext/src/fasttext_model.py
@@ -0,0 +1,70 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FastText model."""
+from mindspore import nn
+from mindspore.ops import operations as P
+from mindspore.common.initializer import XavierUniform
+from mindspore.common import dtype as mstype
+
+class FastText(nn.Cell):
+    """
+    FastText model.
+
+    Args:
+        vocab_size: vocabulary size.
+        embedding_dims: size of each embedding vector.
+        num_class: number of labels.
+    """
+    def __init__(self, vocab_size, embedding_dims, num_class):
+        super(FastText, self).__init__()
+        self.vocab_size = vocab_size
+        self.embedding_dims = embedding_dims
+        self.num_class = num_class
+        self.embedding_func = nn.Embedding(vocab_size=self.vocab_size,
+                                           embedding_size=self.embedding_dims,
+                                           padding_idx=0, embedding_table='Zeros')
+        self.fc = nn.Dense(self.embedding_dims, out_channels=self.num_class,
+                           weight_init=XavierUniform(1)).to_float(mstype.float16)
+        self.reducesum = P.ReduceSum()
+        self.expand_dims = P.ExpandDims()
+        self.squeeze = P.Squeeze(axis=1)
+        self.cast = P.Cast()
+        self.tile = P.Tile()
+        self.realdiv = P.RealDiv()
+        self.fill = P.Fill()
+        self.log_softmax = nn.LogSoftmax(axis=1)
+
+    def construct(self, src_tokens, src_token_length):
+        """
+        construct network
+
+        Args:
+            src_tokens: source sentences.
+            src_token_length: source sentence lengths.
+
+        Returns:
+            Tensor, classification logits.
+        """
+        src_tokens = self.embedding_func(src_tokens)
+        # The hidden layer is the average of the token embeddings:
+        # sum over the sequence dimension, then divide by the sequence length.
+        embedding = self.reducesum(src_tokens, 1)
+
+        length_tiled = self.tile(src_token_length, (1, self.embedding_dims))
+
+        embedding = self.realdiv(embedding, length_tiled)
+
+        embedding = self.cast(embedding, mstype.float16)
+        classifier = self.fc(embedding)
+        classifier = self.cast(classifier, mstype.float32)
+
+        return classifier
diff --git a/model_zoo/official/nlp/fasttext/src/fasttext_train.py b/model_zoo/official/nlp/fasttext/src/fasttext_train.py
new file mode 100644
index 00000000000..86c0d6fbf04
--- /dev/null
+++ b/model_zoo/official/nlp/fasttext/src/fasttext_train.py
@@ -0,0 +1,142 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""FastText for train"""
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.ops import composite as C
+from mindspore.common.parameter import ParameterTuple
+from mindspore.context import ParallelMode
+from mindspore import nn
+from mindspore.communication.management import get_group_size
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
+from mindspore import context
+from src.fasttext_model import FastText
+
+
+GRADIENT_CLIP_TYPE = 1
+GRADIENT_CLIP_VALUE = 1.0
+
+clip_grad = C.MultitypeFuncGraph("clip_grad")
+
+
+@clip_grad.register("Number", "Number", "Tensor")
+def _clip_grad(clip_type, clip_value, grad):
+    """
+    Clip gradients.
+
+    Inputs:
+        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
+        clip_value (float): Specifies how much to clip.
+        grad (Tensor): Gradient to clip.
+
+    Outputs:
+        Tensor, clipped gradient.
+    """
+    if clip_type not in (0, 1):
+        return grad
+    dt = F.dtype(grad)
+    if clip_type == 0:
+        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                   F.cast(F.tuple_to_array((clip_value,)), dt))
+    else:
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+    return new_grad
+
+class FastTextNetWithLoss(nn.Cell):
+    """
+    Provide FastText training loss.
+
+    Args:
+        vocab_size: vocabulary size.
+        embedding_dims: size of each embedding vector.
+        num_class: number of labels.
+    """
+    def __init__(self, vocab_size, embedding_dims, num_class):
+        super(FastTextNetWithLoss, self).__init__()
+        self.fasttext = FastText(vocab_size, embedding_dims, num_class)
+        self.loss_func = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+        self.squeeze = P.Squeeze(axis=1)
+        self.print = P.Print()
+
+    def construct(self, src_tokens, src_tokens_lengths, label_idx):
+        """
+        FastText network with loss.
+        """
+        predict_score = self.fasttext(src_tokens, src_tokens_lengths)
+        label_idx = self.squeeze(label_idx)
+        predict_score = self.loss_func(predict_score, label_idx)
+
+        return predict_score
+
+
+class FastTextTrainOneStepCell(nn.Cell):
+    """
+    Encapsulation class of FastText network training.
+
+    Appends an optimizer to the training network, so that the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that the loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        sens (Number): The adjust parameter. Default: 1.0.
+ """ + def __init__(self, network, optimizer, sens=1.0): + super(FastTextTrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode not in ParallelMode.MODE_LIST: + raise ValueError("Parallel mode does not support: ", self.parallel_mode) + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("gradients_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + self.hyper_map = C.HyperMap() + self.cast = P.Cast() + + def set_sens(self, value): + self.sens = value + + def construct(self, + src_token_text, + src_tokens_text_length, + label_idx_tag): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(src_token_text, + src_tokens_text_length, + label_idx_tag) + grads = self.grad(self.network, weights)(src_token_text, + src_tokens_text_length, + label_idx_tag, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + + succ = self.optimizer(grads) + return F.depend(loss, succ) diff --git a/model_zoo/official/nlp/fasttext/src/load_dataset.py b/model_zoo/official/nlp/fasttext/src/load_dataset.py new file mode 100644 index 00000000000..07dc4a7692c --- /dev/null +++ b/model_zoo/official/nlp/fasttext/src/load_dataset.py @@ -0,0 +1,62 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""FastText data loader""" +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.c_transforms as deC + +def load_dataset(dataset_path, + batch_size, + epoch_count=1, + rank_size=1, + rank_id=0, + bucket=None, + shuffle=True): + """dataset loader""" + def batch_per_bucket(bucket_length, input_file): + input_file = input_file +'/train_dataset_bs_' + str(bucket_length) + '.mindrecord' + if not input_file: + raise FileNotFoundError("input file parameter must not be empty.") + + ds = de.MindDataset(input_file, + columns_list=['src_tokens', 'src_tokens_length', 'label_idx'], + shuffle=shuffle, + num_shards=rank_size, + shard_id=rank_id, + num_parallel_workers=8) + ori_dataset_size = ds.get_dataset_size() + print(f"Dataset size: {ori_dataset_size}") + repeat_count = epoch_count + type_cast_op = deC.TypeCast(mstype.int32) + ds = ds.map(operations=type_cast_op, input_columns="src_tokens") + ds = ds.map(operations=type_cast_op, input_columns="src_tokens_length") + ds = ds.map(operations=type_cast_op, input_columns="label_idx") + + ds = ds.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'], + output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag']) + ds = ds.batch(batch_size, drop_remainder=False) + ds = ds.repeat(repeat_count) + return ds + for i, _ in enumerate(bucket): + bucket_len = bucket[i] + ds_per = batch_per_bucket(bucket_len, dataset_path) + if i == 0: + ds = ds_per + else: + ds = ds + ds_per + ds = ds.shuffle(ds.get_dataset_size()) + ds.channel_name = 'fasttext' + + return ds diff --git a/model_zoo/official/nlp/fasttext/src/lr_schedule.py b/model_zoo/official/nlp/fasttext/src/lr_schedule.py new file mode 100644 index 00000000000..f0a96128808 --- /dev/null +++ b/model_zoo/official/nlp/fasttext/src/lr_schedule.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Learning rate utilities.""" +from math import ceil +import numpy as np + +def polynomial_decay_scheduler(lr, min_lr, decay_steps, total_update_num, warmup_steps=1000, power=1.0): + """ + Implements of polynomial decay learning rate scheduler which cycles by default. + + Args: + lr (float): Initial learning rate. + warmup_steps (int): Warmup steps. + decay_steps (int): Decay steps. + total_update_num (int): Total update steps. + min_lr (float): Min learning. + power (float): Power factor. + + Returns: + np.ndarray, learning rate of each step. 
+ """ + lrs = np.zeros(shape=total_update_num, dtype=np.float32) + + if decay_steps <= 0: + raise ValueError("`decay_steps` must larger than 1.") + + _start_step = 0 + if 0 < warmup_steps < total_update_num: + warmup_end_lr = lr + warmup_init_lr = 0 if warmup_steps > 0 else warmup_end_lr + lrs[:warmup_steps] = np.linspace(warmup_init_lr, warmup_end_lr, warmup_steps) + _start_step = warmup_steps + + decay_steps = decay_steps + for step in range(_start_step, total_update_num): + _step = step - _start_step + ratio = ceil(_step / decay_steps) + ratio = 1 if ratio < 1 else ratio + _decay_steps = decay_steps * ratio + lrs[step] = (lr - min_lr) * pow(1 - _step / _decay_steps, power) + min_lr + + return lrs diff --git a/model_zoo/official/nlp/fasttext/train.py b/model_zoo/official/nlp/fasttext/train.py new file mode 100644 index 00000000000..1fa7572a968 --- /dev/null +++ b/model_zoo/official/nlp/fasttext/train.py @@ -0,0 +1,202 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""FastText for train""" +import os +import time +import argparse +from mindspore import context +from mindspore.nn.optim import Adam +from mindspore.common import set_seed +from mindspore.train.model import Model +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from mindspore.context import ParallelMode +from mindspore.train.callback import Callback, TimeMonitor +from mindspore.communication import management as MultiAscend +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.load_dataset import load_dataset +from src.lr_schedule import polynomial_decay_scheduler +from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss + +parser = argparse.ArgumentParser() +parser.add_argument('--data_path', type=str, required=True, help='FastText input data file path.') +parser.add_argument('--data_name', type=str, required=True, default='ag', help='dataset name. eg. ag, dbpedia') +args = parser.parse_args() + +if args.data_name == "ag": + from src.config import config_ag as config +elif args.data_name == 'dbpedia': + from src.config import config_db as config +elif args.data_name == 'yelp_p': + from src.config import config_yelpp as config + +def get_ms_timestamp(): + t = time.time() + return int(round(t * 1000)) +set_seed(5) +time_stamp_init = False +time_stamp_first = 0 +rank_id = os.getenv('DEVICE_ID') +context.set_context( + mode=context.GRAPH_MODE, + save_graphs=False, + device_target="Ascend") + +class LossCallBack(Callback): + """ + Monitor the loss in training. + + If the loss is NAN or INF terminating training. + + Note: + If per_print_times is 0 do not print loss. + + Args: + per_print_times (int): Print loss every times. Default: 1. 
+ """ + def __init__(self, per_print_times=1, rank_ids=0): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + self.rank_id = rank_ids + global time_stamp_init, time_stamp_first + if not time_stamp_init: + time_stamp_first = get_ms_timestamp() + time_stamp_init = True + + def step_end(self, run_context): + """Monitor the loss in training.""" + global time_stamp_first + time_stamp_current = get_ms_timestamp() + cb_params = run_context.original_args() + print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - time_stamp_first, + cb_params.cur_epoch_num, + cb_params.cur_step_num, + str(cb_params.net_outputs))) + with open("./loss_{}.log".format(self.rank_id), "a+") as f: + f.write("time: {}, epoch: {}, step: {}, loss: {}".format( + time_stamp_current - time_stamp_first, + cb_params.cur_epoch_num, + cb_params.cur_step_num, + str(cb_params.net_outputs.asnumpy()))) + f.write('\n') + + +def _build_training_pipeline(pre_dataset): + """ + Build training pipeline + + Args: + pre_dataset: preprocessed dataset + """ + net_with_loss = FastTextNetWithLoss(config.vocab_size, config.embedding_dims, config.num_class) + net_with_loss.init_parameters_data() + if config.pretrain_ckpt_dir: + parameter_dict = load_checkpoint(config.pretrain_ckpt_dir) + load_param_into_net(net_with_loss, parameter_dict) + if pre_dataset is None: + raise ValueError("pre-process dataset must be provided") + + #get learning rate + update_steps = config.epoch * pre_dataset.get_dataset_size() + decay_steps = pre_dataset.get_dataset_size() + rank_size = os.getenv("RANK_SIZE") + if isinstance(rank_size, int): + raise ValueError("RANK_SIZE must be integer") + if rank_size is not None and int(rank_size) > 1: + base_lr = config.lr + else: + base_lr = config.lr / 10 + print("+++++++++++Total update steps ", update_steps) + lr = Tensor(polynomial_decay_scheduler(lr=base_lr, + min_lr=config.min_lr, + decay_steps=decay_steps, + total_update_num=update_steps, + warmup_steps=config.warmup_steps, + power=config.poly_lr_scheduler_power), dtype=mstype.float32) + optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.999) + + net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer) + net_with_grads.set_train(True) + model = Model(net_with_grads) + loss_monitor = LossCallBack(rank_ids=rank_id) + dataset_size = pre_dataset.get_dataset_size() + time_monitor = TimeMonitor(data_size=dataset_size) + ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps, + keep_checkpoint_max=config.keep_ckpt_max) + callbacks = [time_monitor, loss_monitor] + if rank_size is None or int(rank_size) == 1: + ckpt_callback = ModelCheckpoint(prefix='fasttext', + directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))), + config=ckpt_config) + callbacks.append(ckpt_callback) + if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0: + ckpt_callback = ModelCheckpoint(prefix='fasttext', + directory=os.path.join('./', 'ckpe_{}'.format(os.getenv("DEVICE_ID"))), + config=ckpt_config) + callbacks.append(ckpt_callback) + print("Prepare to Training....") + epoch_size = pre_dataset.get_repeat_count() + print("Epoch size ", epoch_size) + if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1: + print(f" | Rank {MultiAscend.get_rank()} Call model train.") + model.train(epoch=config.epoch, 
+                train_dataset=pre_dataset,
+                callbacks=callbacks,
+                dataset_sink_mode=False)
+
+
+def train_single(input_file_path):
+    """
+    Train the model on a single device.
+
+    Args:
+        input_file_path: preprocessed dataset path
+    """
+    print("Starting training on single device.")
+    preprocessed_data = load_dataset(dataset_path=input_file_path,
+                                     batch_size=config.batch_size,
+                                     epoch_count=config.epoch_count,
+                                     bucket=config.buckets)
+    _build_training_pipeline(preprocessed_data)
+
+
+def set_parallel_env():
+    context.reset_auto_parallel_context()
+    MultiAscend.init()
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                      device_num=MultiAscend.get_group_size(),
+                                      gradients_mean=True)
+
+def train_parallel(input_file_path):
+    """
+    Train the model on multiple devices.
+
+    Args:
+        input_file_path: preprocessed dataset path
+    """
+    set_parallel_env()
+    print("Starting training on multiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|")
+    preprocessed_data = load_dataset(dataset_path=input_file_path,
+                                     batch_size=config.batch_size,
+                                     epoch_count=config.epoch_count,
+                                     rank_size=MultiAscend.get_group_size(),
+                                     rank_id=MultiAscend.get_rank(),
+                                     bucket=config.buckets,
+                                     shuffle=False)
+    _build_training_pipeline(preprocessed_data)
+
+if __name__ == "__main__":
+    _rank_size = os.getenv("RANK_SIZE")
+    if _rank_size is not None and int(_rank_size) > 1:
+        train_parallel(args.data_path)
+    else:
+        train_single(args.data_path)