!12029 Add FastText model for GPU

From: @yuruilee
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-04-23 17:39:20 +08:00 committed by Gitee
commit ad87aca10d
11 changed files with 434 additions and 166 deletions

README.md

@@ -1,4 +1,6 @@
![](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)
# FastText
![mindspore](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)
<!-- TOC -->
@@ -22,7 +24,7 @@
<!-- /TOC -->
# [FastText](#contents)
## [FastText](#contents)
FastText is a fast text classification algorithm that is simple and efficient. It was proposed by Armand
Joulin, Tomas Mikolov, et al. in the 2016 article "Bag of Tricks for Efficient Text Classification". It is similar to
@@ -32,13 +34,13 @@ in various tasks of text classification.
[Paper](https://arxiv.org/pdf/1607.01759.pdf): "Bag of Tricks for Efficient Text Classification", 2016, A. Joulin, E. Grave, P. Bojanowski, and T. Mikolov
# [Model Structure](#contents)
## [Model Structure](#contents)
The FastText model mainly consists of an input layer, a hidden layer and an output layer, where the input is a sequence of words (text or sentence).
The output layer gives the probability that the word sequence belongs to each category. The hidden layer is formed by averaging multiple word vectors.
The features are mapped to the hidden layer through a linear transformation, and then mapped from the hidden layer to the labels.
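As a rough illustration of this structure, the following NumPy sketch mimics the forward computation with hypothetical sizes; the real hyperparameters come from config.py and the real model code lives under `src/`.

```python
import numpy as np

# Hypothetical sizes for illustration only; the real values come from src/config.py.
vocab_size, embedding_dims, num_class, seq_len = 1000, 16, 4, 12

rng = np.random.default_rng(0)
embedding = rng.normal(size=(vocab_size, embedding_dims)).astype(np.float32)  # input-layer lookup table
fc_weight = rng.normal(size=(embedding_dims, num_class)).astype(np.float32)   # hidden -> label map

token_ids = rng.integers(0, vocab_size, size=seq_len)  # one tokenized sentence
hidden = embedding[token_ids].mean(axis=0)             # hidden layer: average of the word vectors
logits = hidden @ fc_weight                            # linear transformation to label space
probs = np.exp(logits) / np.exp(logits).sum()          # probability over categories
print(probs)
```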
# [Dataset](#contents)
## [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in the original paper or one widely used in the relevant domain/network
architecture. In the following sections, we will introduce how to run the scripts using the related datasets below.
@@ -47,17 +49,17 @@ architecture. In the following sections, we will introduce how to run the script
- DBPedia Ontology Classification Dataset
- Yelp Review Polarity Dataset
# [Environment Requirements](#content)
## [Environment Requirements](#content)
- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor.
- Hardware (Ascend/GPU)
    - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#content)
## [Quick Start](#content)
After dataset preparation, you can start training and evaluation as follows:
@@ -73,7 +75,7 @@ sh run_distribute_train.sh [TRAIN_DATASET] [RANK_TABLE_PATH]
sh run_eval.sh [EVAL_DATASET_PATH] [DATASET_NAME] [MODEL_CKPT] [DEVICEID]
```
# [Script Description](#content)
## [Script Description](#content)
The FastText network scripts and code structure are as follows:
@@ -91,12 +93,15 @@ The FastText network script and code result are as follows:
│ ├──run_distributed_train.sh // shell script for distributed train on Ascend.
│ ├──run_eval.sh // shell script for standalone eval on Ascend.
│ ├──run_standalone_train.sh // shell script for standalone train on Ascend.
│ ├──run_distributed_train_gpu.sh // shell script for distributed train on GPU.
│ ├──run_eval_gpu.sh // shell script for standalone eval on GPU.
│ ├──run_standalone_train_gpu.sh // shell script for standalone train on GPU.
├── eval.py // Infer API entry.
├── requirements.txt // Requirements of third party package.
├── train.py // Train API entry.
```
## [Dataset Preparation](#content)
### [Dataset Preparation](#content)
- Download the AG's News Topic Classification Dataset, DBPedia Ontology Classification Dataset and Yelp Review Polarity Dataset. Unzip datasets to any path you want.
@@ -107,151 +112,182 @@ The FastText network script and code result are as follows:
sh creat_dataset.sh [SOURCE_DATASET_PATH] [DATASET_NAME]
```
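For reference, the conversion produces MindRecord files. Below is a minimal sketch of writing one record with MindSpore's `FileWriter`; the column names and shapes are assumptions for illustration only, since the actual schema is defined by the conversion code that `creat_dataset.sh` drives.

```python
import numpy as np
from mindspore.mindrecord import FileWriter

# Column names and shapes here are assumptions for illustration; the real schema
# is defined by the dataset conversion scripts.
schema = {"src_tokens": {"type": "int32", "shape": [-1]},
          "src_tokens_length": {"type": "int32", "shape": [-1]},
          "label_idx": {"type": "int32", "shape": [-1]}}

writer = FileWriter(file_name="ag.mindrecord", shard_num=1)
writer.add_schema(schema, "fasttext dataset")
writer.write_raw_data([{"src_tokens": np.array([4, 8, 15], dtype=np.int32),
                        "src_tokens_length": np.array([3], dtype=np.int32),
                        "label_idx": np.array([1], dtype=np.int32)}])
writer.commit()
```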
## [Configuration File](#content)
### [Configuration File](#content)
Parameters for both training and evaluation can be set in config.py. All datasets use the same parameter names; parameter values can be changed according to your needs (see the learning-rate sketch after the parameter list below).
- Network Parameters
```text
vocab_size # vocabulary size.
buckets # bucket sequence length.
test_buckets # test dataset bucket sequence length
batch_size # batch size of input dataset.
embedding_dims # The size of each embedding vector.
num_class # number of labels.
epoch # total training epochs.
lr # initial learning rate.
min_lr # minimum learning rate.
warmup_steps # warm up steps.
poly_lr_scheduler_power # a value used to calculate decayed learning rate.
pretrain_ckpt_dir # pretrain checkpoint directory.
keep_ckpt_max # Max ckpt files number.
```
## [Training Process](#content)
- Start task training on a single device and run the shell script
```bash
cd ./scripts
sh run_standalone_train.sh [DATASET_PATH] [DEVICEID]
```
```text
vocab_size # vocabulary size.
buckets # bucket sequence length.
test_buckets # test dataset bucket sequence length
batch_size # batch size of input dataset.
embedding_dims # The size of each embedding vector.
num_class # number of labels.
epoch # total training epochs.
lr # initial learning rate.
min_lr # minimum learning rate.
warmup_steps # warm up steps.
poly_lr_scheduler_power # a value used to calculate decayed learning rate.
pretrain_ckpt_dir # pretrain checkpoint directory.
keep_ckpt_max # Max ckpt files number.
```
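As referenced above, one plausible reading of `lr`, `min_lr`, `decay_steps` (set in src/config.py) and `poly_lr_scheduler_power` is a standard polynomial-decay schedule. The sketch below is an assumption about the formula, not the actual code from `src/`:

```python
def poly_decay_lr(step, lr, min_lr, decay_steps, power):
    """Assumed polynomial-decay schedule built from the parameters above."""
    step = min(step, decay_steps)
    return (lr - min_lr) * (1.0 - step / decay_steps) ** power + min_lr

# Example with the AG News GPU values from src/config.py (lr=0.2, min_lr=1e-6,
# decay_steps=115, poly_lr_scheduler_power=0.001):
print(poly_decay_lr(58, 0.2, 1e-6, 115, 0.001))
```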
- Running scripts for distributed training of FastText. To train the model on multiple devices, run the following command in bash from `scripts/`:
### [Training Process](#content)
``` bash
cd ./scripts
sh run_distributed_train.sh [DATASET_PATH] [RANK_TABLE_PATH]
```
- Running on Ascend
## [Inference Process](#content)
- Start task training on a single device and run the shell script
- Running scripts for evaluation of FastText. The command is as below.
```bash
cd ./scripts
sh run_standalone_train.sh [DATASET_PATH]
```
``` bash
cd ./scripts
sh run_eval.sh [DATASET_PATH] [DATASET_NAME] [MODEL_CKPT] [DEVICEID]
```
- Running scripts for distributed training of FastText. To train the model on multiple devices, run the following command in bash from `scripts/`:
Note: The `DATASET_PATH` is the path to the mindrecord files, e.g. `/dataset_path/*.mindrecord`
```bash
cd ./scripts
sh run_distributed_train.sh [DATASET_PATH] [RANK_TABLE_PATH]
```
# [Model Description](#content)
- Running on GPU
## [Performance](#content)
- Start task training on a single device and run the shell script
### Training Performance
```bash
cd ./scripts
sh run_standalone_train_gpu.sh [DATASET_PATH]
```
| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910; OS Euler2.8 |
| uploaded Date | 12/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | AG's News Topic Classification Dataset |
| Training Parameters | epoch=5, batch_size=512 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Speed | 10ms/step (1pcs) |
| Epoch Time | 2.36s (1pcs) |
| Loss | 0.0067 |
| Params (M) | 22 |
| Checkpoint for inference | 254M (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
- Running scripts for distributed training of FastText. To train the model on multiple devices, run the following command in bash from `scripts/`:
| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource |Ascend 910; OS Euler2.8 |
| uploaded Date | 11/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | DBPedia Ontology Classification Dataset |
| Training Parameters | epoch=5, batch_size=4096 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Speed | 58ms/step (1pcs) |
| Epoch Time | 8.15s (1pcs) |
| Loss | 2.6e-4 |
| Params (M) | 106 |
| Checkpoint for inference | 1.2G (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
```bash
cd ./scripts
sh run_distributed_train_gpu.sh [DATASET_PATH] [NUM_OF_DEVICES]
```
| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910; OS Euler2.8 |
| uploaded Date | 11/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Yelp Review Polarity Dataset |
| Training Parameters | epoch=5, batch_size=2048 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| outputs | probability |
| Speed | 101ms/step (1pcs) |
| Epoch Time | 28s (1pcs) |
| Loss | 0.062 |
| Params (M) | 103 |
| Checkpoint for inference | 1.2G (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
### [Inference Process](#content)
### Inference Performance
- Running on Ascend
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource | Ascend 910; OS Euler2.8 |
| Uploaded Date | 12/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | AG's News Topic Classification Dataset |
| batch_size | 512 |
| Epoch Time | 2.36s |
| outputs | label index |
| Accuracy | 92.53 |
| Model for inference | 254M (.ckpt file) |
- Running scripts for evaluation of FastText. The command is as below.
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource | Ascend 910; OS Euler2.8 |
| Uploaded Date | 12/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | DBPedia Ontology Classification Dataset |
| batch_size | 4096 |
| Epoch Time | 8.15s |
| outputs | label index |
| Accuracy | 98.6 |
| Model for inference | 1.2G (.ckpt file) |
```bash
cd ./scripts
sh run_eval.sh [DATASET_PATH] [DATASET_NAME] [MODEL_CKPT]
```
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource |Ascend 910; OS Euler2.8 |
| Uploaded Date | 12/21/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Yelp Review Polarity Dataset |
| batch_size | 2048 |
| Epoch Time | 28s |
| outputs | label index |
| Accuracy | 95.7 |
| Model for inference | 1.2G (.ckpt file) |
Note: The `DATASET_PATH` is the path to the mindrecord files, e.g. `/dataset_path/*.mindrecord`
# [Random Situation Description](#content)
- Running on GPU
- Running scripts for evaluation of FastText. The command is as below.
```bash
cd ./scripts
sh run_eval_gpu.sh [DATASET_PATH] [DATASET_NAME] [MODEL_CKPT]
```
Note: The `DATASET_PATH` is the path to the mindrecord files, e.g. `/dataset_path/*.mindrecord`
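To sanity-check a converted dataset before running evaluation, the mindrecord files can be inspected directly with `MindDataset`. The column names below are assumptions for illustration; the actual loader is `src/load_dataset.py`.

```python
import mindspore.dataset as ds

# Column names are assumptions for illustration; the real loader is src/load_dataset.py.
data_set = ds.MindDataset(dataset_file="/dataset_path/test_ag_0.mindrecord",
                          columns_list=["src_tokens", "src_tokens_length", "label_idx"])
for item in data_set.create_dict_iterator(output_numpy=True):
    print(item["src_tokens"].shape, item["label_idx"])
    break
```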
## [Model Description](#content)
### [Performance](#content)
#### Training Performance
| Parameters | Ascend | GPU |
| ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
| uploaded Date | 12/21/2020 (month/day/year) | 1/29/2021 (month/day/year) |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | AG's News Topic Classification Dataset | AG's News Topic Classification Dataset |
| Training Parameters | epoch=5, batch_size=512 | epoch=5, batch_size=512 |
| Optimizer | Adam | Adam |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | probability |
| Speed | 10ms/step (1pcs) | 11.91ms/step(1pcs) |
| Epoch Time | 2.36s (1pcs) | 2.815s(1pcs) |
| Loss | 0.0067 | 0.0085 |
| Params (M) | 22 | 22 |
| Checkpoint for inference | 254M (.ckpt file) | 254M (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
| Parameters | Ascend | GPU |
| ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
| uploaded Date | 11/21/2020 (month/day/year) | 1/29/2021 (month/day/year) |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | DBPedia Ontology Classification Dataset | DBPedia Ontology Classification Dataset |
| Training Parameters | epoch=5, batch_size=4096 | epoch=5, batch_size=4096 |
| Optimizer | Adam | Adam |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | probability |
| Speed | 58ms/step (1pcs) | 34.82ms/step(1pcs) |
| Epoch Time | 8.15s (1pcs) | 4.87s(1pcs) |
| Loss | 2.6e-4 | 0.0004 |
| Params (M) | 106 | 106 |
| Checkpoint for inference | 1.2G (.ckpt file) | 1.2G (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
| Parameters | Ascend | GPU |
| ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
| uploaded Date | 11/21/2020 (month/day/year) | 1/29/2021 (month/day/year) |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | Yelp Review Polarity Dataset | Yelp Review Polarity Dataset |
| Training Parameters | epoch=5, batch_size=2048 | epoch=5, batch_size=2048 |
| Optimizer | Adam | Adam |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | probability |
| Speed | 101ms/step (1pcs) | 30.54ms/step(1pcs) |
| Epoch Time | 28s (1pcs) | 8.46s(1pcs) |
| Loss | 0.062 | 0.002 |
| Params (M) | 103 | 103 |
| Checkpoint for inference | 1.2G (.ckpt file) | 1.2G (.ckpt file) |
| Scripts | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) | [fasttext](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/fasttext) |
#### Inference Performance
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | ------------------- |
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
| Uploaded Date | 12/21/2020 (month/day/year) | 1/29/2021 (month/day/year) |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | AG's News Topic Classification Dataset | AG's News Topic Classification Dataset |
| batch_size | 512 | 128 |
| Epoch Time | 2.36s | 2.815s(1pcs) |
| outputs | label index | label index |
| Accuracy | 92.53 | 92.58 |
| Model for inference | 254M (.ckpt file) | 254M (.ckpt file) |
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | ------------------- |
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
| Uploaded Date | 12/21/2020 (month/day/year) | 1/29/2021 (month/day/year) |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | DBPedia Ontology Classification Dataset | DBPedia Ontology Classification Dataset |
| batch_size | 4096 | 4096 |
| Epoch Time | 8.15s | 4.87s |
| outputs | label index | label index |
| Accuracy | 98.6 | 98.49 |
| Model for inference | 1.2G (.ckpt file) | 1.2G (.ckpt file) |
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | ------------------- |
| Resource | Ascend 910; OS Euler2.8 | NV SMX3 V100-32G |
| Uploaded Date | 12/21/2020 (month/day/year) | 12/29/2020 (month/day/year) |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | Yelp Review Polarity Dataset | Yelp Review Polarity Dataset |
| batch_size | 2048 | 2048 |
| Epoch Time | 28s | 8.46s |
| outputs | label index | label index |
| Accuracy | 95.7 | 95.7 |
| Model for inference | 1.2G (.ckpt file) | 1.2G (.ckpt file) |
## [Random Situation Description](#content)
There is only one random situation.
@@ -259,10 +295,10 @@ There only one random situation.
Some seeds have already been set in train.py to avoid the randomness of weight initialization.
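For example, a single call to MindSpore's `set_seed` (which train.py imports) is enough to make weight initialization deterministic; the value below is illustrative, not necessarily the seed train.py uses:

```python
from mindspore.common import set_seed

set_seed(5)  # illustrative value; fixing the global seed makes weight initialization reproducible
```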
# [Others](#others)
## [Others](#others)
This model has been validated in the Ascend environment and has not been validated on CPU or GPU.
# [ModelZoo HomePage](#contents)
## [ModelZoo HomePage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)

eval.py

@@ -32,21 +32,27 @@ parser.add_argument('--data_name', type=str, required=True, default='ag',
help='dataset name. eg. ag, dbpedia')
parser.add_argument("--model_ckpt", type=str, required=True,
help="existed checkpoint address.")
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
args = parser.parse_args()
if args.data_name == "ag":
from src.config import config_ag as config_ascend
from src.config import config_ag_gpu as config_gpu
target_label1 = ['0', '1', '2', '3']
elif args.data_name == 'dbpedia':
from src.config import config_db as config_ascend
from src.config import config_db_gpu as config_gpu
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
elif args.data_name == 'yelp_p':
from src.config import config_yelpp as config_ascend
from src.config import config_yelpp_gpu as config_gpu
target_label1 = ['0', '1']
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend")
device_target=args.device_target)
config = config_ascend if args.device_target == 'Ascend' else config_gpu
class FastTextInferCell(nn.Cell):
"""
Encapsulation class of FastText network infer.

export.py

@@ -37,18 +37,22 @@ args = parser.parse_args()
if args.data_name == "ag":
from src.config import config_ag as config_ascend
from src.config import config_ag_gpu as config_gpu
target_label1 = ['0', '1', '2', '3']
elif args.data_name == 'dbpedia':
from src.config import config_db as config_ascend
from src.config import config_db_gpu as config_gpu
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
elif args.data_name == 'yelp_p':
from src.config import config_yelpp as config_ascend
from src.config import config_yelpp_gpu as config_gpu
target_label1 = ['0', '1']
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend")
device_target=args.device_target)
config = config_ascend if args.device_target == 'Ascend' else config_gpu
class FastTextInferExportCell(nn.Cell):
"""
@@ -80,16 +84,18 @@ def run_fasttext_export():
parameter_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(fasttext_model, parameter_dict)
ft_infer = FastTextInferExportCell(fasttext_model)
batch_size = config.batch_size
if args.device_target == 'GPU':
batch_size = config.distribute_batch_size
if args.data_name == "ag":
src_tokens_shape = [config.batch_size, 467]
src_tokens_length_shape = [config.batch_size, 1]
src_tokens_shape = [batch_size, 467]
src_tokens_length_shape = [batch_size, 1]
elif args.data_name == 'dbpedia':
src_tokens_shape = [config.batch_size, 1120]
src_tokens_length_shape = [config.batch_size, 1]
src_tokens_shape = [batch_size, 1120]
src_tokens_length_shape = [batch_size, 1]
elif args.data_name == 'yelp_p':
src_tokens_shape = [config.batch_size, 2955]
src_tokens_length_shape = [config.batch_size, 1]
src_tokens_shape = [batch_size, 2955]
src_tokens_length_shape = [batch_size, 1]
file_name = args.file_name + '_' + args.data_name
src_tokens = Tensor(np.ones((src_tokens_shape)).astype(np.int32))

requirements.txt

@@ -1,3 +1,3 @@
spacy
spacy==2.3.1
sklearn
en_core_web_lg
en_core_web_lg==2.3.1


scripts/run_distributed_train_gpu.sh

@@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distributed_train_gpu.sh DATASET_PATH DEVICE_NUM"
echo "for example: sh run_distributed_train_gpu.sh /home/workspace/ag 8"
echo "It is better to use absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
echo $DATASET
DATANAME=$(basename $DATASET)
echo $DATANAME
if [ -d "distribute_train" ];
then
rm -rf ./distribute_train
fi
mkdir ./distribute_train
cp ../*.py ./distribute_train
cp -r ../src ./distribute_train
cp -r ../scripts/*.sh ./distribute_train
cd ./distribute_train || exit
echo "start training for $2 GPU devices"
mpirun -n $2 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python ../../train.py --device_target GPU --run_distribute True --data_path $DATASET --data_name $DATANAME
cd ..

scripts/run_eval_gpu.sh

@@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_eval_gpu.sh DATASET_PATH DATASET_NAME MODEL_CKPT"
echo "for example: sh run_eval_gpu.sh /home/workspace/ag/test*.mindrecord ag device0/ckpt0/fasttext-5-118.ckpt"
echo "It is better to use absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
echo $DATASET
DATANAME=$2
MODEL_CKPT=$(get_real_path $3)
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cp -r ../scripts/*.sh ./eval
cd ./eval || exit
echo "start eval on standalone GPU"
python ../../eval.py --device_target GPU --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
cd ..

scripts/run_standalone_train.sh

@@ -51,5 +51,6 @@ cp -r ../scripts/*.sh ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
#python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
python train.py --data_path $DATASET --data_name $DATANAME
cd ..

scripts/run_standalone_train_gpu.sh

@@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_standalone_train_gpu.sh DATASET_PATH"
echo "for example: sh run_standalone_train_gpu.sh /home/workspace/ag"
echo "It is better to use absolute path."
echo "=============================================================================================================="
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET=$(get_real_path $1)
echo $DATASET
DATANAME=$(basename $DATASET)
echo $DATANAME
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cp -r ../scripts/*.sh ./train
cd ./train || exit
echo "start training for standalone GPU device"
python train.py --device_target="GPU" --data_path=$DATASET --data_name=$DATANAME > log_fasttext.log 2>&1 &
cd ..

src/config.py

@@ -73,3 +73,63 @@ config_ag = ed({
'save_ckpt_steps': 116,
'keep_ckpt_max': 10,
})
config_yelpp_gpu = ed({
'vocab_size': 6414979,
'buckets': [64, 128, 256, 512, 2955],
'test_buckets': [64, 128, 256, 512, 2955],
'batch_size': 2048,
'distribute_batch_size': 512,
'embedding_dims': 16,
'num_class': 2,
'epoch': 5,
'lr': 0.30,
'min_lr': 1e-6,
'decay_steps': 549,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 549,
'keep_ckpt_max': 10,
})
config_db_gpu = ed({
'vocab_size': 6596536,
'buckets': [64, 128, 256, 512, 3013],
'test_buckets': [64, 128, 256, 512, 1120],
'batch_size': 4096,
'distribute_batch_size': 512,
'embedding_dims': 16,
'num_class': 14,
'epoch': 5,
'lr': 0.8,
'min_lr': 1e-6,
'decay_steps': 549,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 548,
'keep_ckpt_max': 10,
})
config_ag_gpu = ed({
'vocab_size': 1383812,
'buckets': [64, 128, 467],
'test_buckets': [467],
'batch_size': 512,
'distribute_batch_size': 64,
'embedding_dims': 16,
'num_class': 4,
'epoch': 5,
'lr': 0.2,
'min_lr': 1e-6,
'decay_steps': 115,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.001,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 116,
'keep_ckpt_max': 10,
})
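The `ed` wrapper above is presumably easydict's `EasyDict` (the usual model_zoo convention), which lets the rest of the code read these entries with attribute access. A minimal self-contained sketch:

```python
from easydict import EasyDict as ed

# Two illustrative entries from config_ag_gpu above; attribute- and dict-style
# access both work on an EasyDict.
config_ag_gpu = ed({'batch_size': 512, 'distribute_batch_size': 64})
print(config_ag_gpu.batch_size)                # 512 for single-GPU training
print(config_ag_gpu['distribute_batch_size'])  # 64 per device for distributed GPU training
```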

train.py

@@ -16,6 +16,7 @@
import os
import time
import argparse
import ast
from mindspore import context
from mindspore.nn.optim import Adam
from mindspore.common import set_seed
@@ -24,7 +25,7 @@ import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback, TimeMonitor
from mindspore.communication import management as MultiAscend
from mindspore.communication import management as MultiDevice
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.load_dataset import load_dataset
@@ -34,14 +35,21 @@ from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, required=True, help='FastText input data file path.')
parser.add_argument('--data_name', type=str, required=True, default='ag', help='dataset name. eg. ag, dbpedia')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute, default: false.')
args = parser.parse_args()
if args.data_name == "ag":
from src.config import config_ag as config
from src.config import config_ag as config_ascend
from src.config import config_ag_gpu as config_gpu
elif args.data_name == 'dbpedia':
from src.config import config_db as config
from src.config import config_db as config_ascend
from src.config import config_db_gpu as config_gpu
elif args.data_name == 'yelp_p':
from src.config import config_yelpp as config
from src.config import config_yelpp as config_ascend
from src.config import config_yelpp_gpu as config_gpu
def get_ms_timestamp():
t = time.time()
@@ -53,7 +61,8 @@ rank_id = os.getenv('DEVICE_ID')
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend")
device_target=args.device_target)
config = config_ascend if args.device_target == 'Ascend' else config_gpu
class LossCallBack(Callback):
"""
@@ -96,7 +105,7 @@ class LossCallBack(Callback):
f.write('\n')
def _build_training_pipeline(pre_dataset):
def _build_training_pipeline(pre_dataset, run_distribute=False):
"""
Build training pipeline
@@ -139,12 +148,12 @@ def _build_training_pipeline(pre_dataset):
ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps * config.epoch,
keep_checkpoint_max=config.keep_ckpt_max)
callbacks = [time_monitor, loss_monitor]
if rank_size is None or int(rank_size) == 1:
if not run_distribute:
ckpt_callback = ModelCheckpoint(prefix='fasttext',
directory=os.path.join('./', 'ckpt_{}'.format(os.getenv("DEVICE_ID"))),
config=ckpt_config)
callbacks.append(ckpt_callback)
if rank_size is not None and int(rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
if run_distribute and MultiDevice.get_rank() % 8 == 0:
ckpt_callback = ModelCheckpoint(prefix='fasttext',
directory=os.path.join('./', 'ckpt_{}'.format(os.getenv("DEVICE_ID"))),
config=ckpt_config)
@@ -152,8 +161,8 @@
print("Prepare to Training....")
epoch_size = pre_dataset.get_repeat_count()
print("Epoch size ", epoch_size)
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1:
print(f" | Rank {MultiAscend.get_rank()} Call model train.")
if run_distribute:
print(f" | Rank {MultiDevice.get_rank()} Call model train.")
model.train(epoch=config.epoch, train_dataset=pre_dataset, callbacks=callbacks, dataset_sink_mode=False)
@@ -173,9 +182,9 @@ def train_single(input_file_path):
def set_parallel_env():
context.reset_auto_parallel_context()
MultiAscend.init()
MultiDevice.init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=MultiAscend.get_group_size(),
device_num=MultiDevice.get_group_size(),
gradients_mean=True)
def train_paralle(input_file_path):
"""
@@ -185,18 +194,21 @@
"""
set_parallel_env()
print("Starting traning on multiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|")
batch_size = config.batch_size
if args.device_target == 'GPU':
batch_size = config.distribute_batch_size
preprocessed_data = load_dataset(dataset_path=input_file_path,
batch_size=config.batch_size,
batch_size=batch_size,
epoch_count=config.epoch_count,
rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank(),
rank_size=MultiDevice.get_group_size(),
rank_id=MultiDevice.get_rank(),
bucket=config.buckets,
shuffle=False)
_build_training_pipeline(preprocessed_data)
_build_training_pipeline(preprocessed_data, True)
if __name__ == "__main__":
_rank_size = os.getenv("RANK_SIZE")
if _rank_size is not None and int(_rank_size) > 1:
if args.run_distribute:
train_paralle(args.data_path)
else:
train_single(args.data_path)
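One detail worth noting in the entry point above: `--run_distribute` is parsed with `ast.literal_eval`, so the flag must be passed as a Python literal. A small self-contained sketch of the behavior:

```python
import ast

# train.py parses --run_distribute with ast.literal_eval, so the flag must be a
# Python literal such as True or False (not an arbitrary string):
print(ast.literal_eval("True"))   # -> True, selects train_paralle(args.data_path)
print(ast.literal_eval("False"))  # -> False, selects train_single(args.data_path)
```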