forked from mindspore-Ecosystem/mindspore
!9016 delete input_mask_from_dataset option in transformer's config
From: @yuchaojie Reviewed-by: @liangchenghui, @linqingke Signed-off-by: @linqingke
Commit: a0e4db6aae

@@ -18,28 +18,24 @@
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

## [Transformer Description](#contents)

Transformer was proposed in 2017 and designed to process sequential data. It is adopted mainly in the field of natural language processing (NLP), for tasks like machine translation or text summarization. Unlike traditional recurrent neural networks (RNNs), which process data in order, Transformer adopts an attention mechanism and improves parallelism, which reduces training time and makes training on larger datasets possible. Since the Transformer model was introduced, it has been used to tackle many problems in NLP and has given rise to many network models, such as BERT (Bidirectional Encoder Representations from Transformers) and GPT (Generative Pre-trained Transformer).

[Paper](https://arxiv.org/abs/1706.03762): Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In NIPS 2017, pages 5998–6008.

## [Model Architecture](#contents)

Specifically, Transformer contains six encoder modules and six decoder modules. Each encoder module consists of a self-attention layer and a feed-forward layer; each decoder module consists of a self-attention layer, an encoder-decoder attention layer and a feed-forward layer.
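
All of these attention layers are built on the scaled dot-product attention defined in the paper cited above. As a rough, framework-free illustration only (this is not code from this repository), a minimal NumPy sketch of that operation:

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, as in Vaswani et al., 2017."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)          # (batch, q_len, k_len)
    scores = scores - scores.max(axis=-1, keepdims=True)      # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)   # softmax over the keys
    return weights @ v                                        # (batch, q_len, d_v)

# Toy check: batch of 2 sequences, length 4, width 8.
q = k = v = np.random.randn(2, 4, 8)
print(scaled_dot_product_attention(q, k, v).shape)  # (2, 4, 8)
```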

## [Dataset](#contents)

Note that you can run the scripts based on the dataset mentioned in the original paper or one widely used in the relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related datasets below.

- *WMT English-German* for training.
- *WMT newstest2014* for evaluation.

## [Environment Requirements](#contents)

- Hardware (Ascend/GPU)
    - Prepare the hardware environment with an Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.

@@ -49,8 +45,7 @@ Note that you can run the scripts based on the dataset mentioned in original pap
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

## [Quick Start](#contents)

After dataset preparation, you can start training and evaluation as follows:

```bash
@@ -65,10 +60,9 @@ sh scripts/run_distribute_train_ascend.sh 8 52 /path/ende-l128-mindrecord rank_t
python eval.py > eval.log 2>&1 &
```

## [Script Description](#contents)

### [Script and Sample Code](#contents)

```shell
.
@@ -96,10 +90,11 @@ python eval.py > eval.log 2>&1 &
└─train.py
```

### [Script Parameters](#contents)

#### Training Script Parameters

```text
usage: train.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
                [--enable_save_ckpt ENABLE_SAVE_CKPT]
                [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
@@ -123,8 +118,9 @@ options:
    --bucket_boundaries        sequence lengths for different buckets: LIST, default is [16, 32, 48, 64, 128]
```

#### Running Options

```text
config.py:
    transformer_network         version of Transformer model: base | large, default is large
    init_loss_scale_value       initial value of loss scale: N, default is 2^10
@@ -139,8 +135,9 @@ eval_config.py:
    output_file                 output file of evaluation: PATH
```

#### Network Parameters

```text
Parameters for dataset and network (Training/Evaluation):
    batch_size                  batch size of input dataset: N, default is 96
    seq_length                  max length of input sequence: N, default is 128
@@ -155,7 +152,6 @@ Parameters for dataset and network (Training/Evaluation):
    max_position_embeddings     maximum length of sequences: N, default is 128
    initializer_range           initialization value of TruncatedNormal: Q, default is 0.02
    label_smoothing             label smoothing setting: Q, default is 0.1
    beam_width                  beam width setting: N, default is 4
    max_decode_length           max decode length in evaluation: N, default is 80
    length_penalty_weight       normalize scores of translations according to their length: Q, default is 1.0
@@ -169,6 +165,7 @@ Parameters for learning rate:
```
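
For reference, `label_smoothing` above follows the usual formulation: the one-hot target for each token is mixed with a uniform distribution over the remaining vocabulary entries. A small illustrative sketch (one common formulation, not this repository's implementation):

```python
import numpy as np

def smooth_labels(target_ids, vocab_size, smoothing=0.1):
    """Soft targets: 1 - smoothing on the true token, smoothing spread over the rest."""
    on_value = 1.0 - smoothing
    off_value = smoothing / (vocab_size - 1)
    soft = np.full((len(target_ids), vocab_size), off_value)
    soft[np.arange(len(target_ids)), target_ids] = on_value
    return soft

print(smooth_labels([5, 2], vocab_size=8).sum(axis=-1))  # each row sums to 1.0
```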

## [Dataset Preparation](#contents)

- You may use this [shell script](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/wmt16_en_de.sh) to download and preprocess the WMT English-German dataset. Assuming you get the following files:
    - train.tok.clean.bpe.32000.en
    - train.tok.clean.bpe.32000.de

```bash
@@ -183,6 +180,7 @@
paste train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.all
python create_data.py --input_file train.all --vocab_file vocab.bpe.32000 --output_file /path/ende-l128-mindrecord --max_seq_length 128 --bucket [16,32,48,64,128]
```

- Convert the original data to mindrecord for evaluation:

```bash
@@ -190,7 +188,6 @@
python create_data.py --input_file test.all --vocab_file vocab.bpe.32000 --output_file /path/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True --bucket [128]
```
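
The `--bucket` option groups examples by length so that each batch is padded only up to its bucket boundary. As an illustration of the idea only (not the code in `create_data.py`), a sequence goes into the smallest boundary that fits it:

```python
def bucket_for(length, boundaries=(16, 32, 48, 64, 128)):
    """Return the smallest bucket boundary that can hold a sequence of this length."""
    for boundary in boundaries:
        if length <= boundary:
            return boundary
    return None  # longer than the last boundary; clipped or dropped depending on the options

print([bucket_for(n) for n in (7, 20, 64, 100)])  # [16, 32, 64, 128]
```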

## [Training Process](#contents)

- Set options in `config.py`, including loss scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about the dataset.

@@ -200,13 +197,13 @@

```bash
sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH
```

- Run `run_distribute_train_ascend.sh` for distributed training of the Transformer model.

```bash
sh scripts/run_distribute_train_ascend.sh DEVICE_NUM EPOCH_SIZE DATA_PATH RANK_TABLE_FILE
```

## [Evaluation Process](#contents)

- Set options in `eval_config.py`. Make sure `data_file`, `model_file` and `output_file` are set to your own paths (see the sketch below).
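
As an illustration only, the three settings typically end up looking like the following; the actual layout of `eval_config.py` may differ, and every value here is a placeholder:

```python
# eval_config.py (placeholder values)
data_file = "/path/newstest2014-l128-mindrecord"   # evaluation mindrecord produced above
model_file = "/path/to/trained/transformer.ckpt"   # checkpoint saved during training
output_file = "/path/eval_output"                  # where the translation results are written
```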

@@ -222,6 +219,7 @@

```bash
sh scripts/process_output.sh REF_DATA EVAL_OUTPUT VOCAB_FILE
```

You will get two files, REF_DATA.forbleu and EVAL_OUTPUT.forbleu, for BLEU score calculation.

- To calculate the BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run the following command:

```bash
@@ -230,10 +228,11 @@
perl multi-bleu.perl REF_DATA.forbleu < EVAL_OUTPUT.forbleu
```

## [Model Description](#contents)

### [Performance](#contents)

#### Training Performance

| Parameters                 | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
@@ -249,10 +248,9 @@
| Loss                       | 2.8                                                          |
| Params (M)                 | 213.7                                                        |
| Checkpoint for inference   | 2.4G (.ckpt file)                                            |
| Scripts                    | [Transformer scripts](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer) |

#### Evaluation Performance

| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
@@ -264,17 +262,16 @@
| outputs             | BLEU score                  |
| Accuracy            | BLEU=28.7                   |

## [Description of Random Situation](#contents)

There are three random situations:

- Shuffle of the dataset.
- Initialization of some model weights.
- Dropout operations.

Some seeds have already been set in `train.py` to avoid the randomness of dataset shuffling and weight initialization. If you want to disable dropout, please set the corresponding dropout_prob parameters to 0 in `src/config.py`, as sketched below.
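
A minimal sketch of turning dropout off; the attribute names match the configuration shown in the code changes below, but the import path is an assumption:

```python
# Patch the network config before building the model; equivalently, edit the same
# values directly in src/config.py.
from src.config import transformer_net_cfg_large as cfg  # assumed import path

cfg.hidden_dropout_prob = 0.0            # no dropout in the hidden/feed-forward layers
cfg.attention_probs_dropout_prob = 0.0   # no dropout on the attention probabilities
```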

## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -32,7 +32,6 @@ transformer_net_cfg_large = TransformerConfig(
    max_position_embeddings=128,
    initializer_range=0.02,
    label_smoothing=0.1,
    dtype=mstype.float32,
    compute_type=mstype.float16
)

@@ -49,7 +49,6 @@ if cfg.transformer_network == 'large':
        max_position_embeddings=128,
        initializer_range=0.02,
        label_smoothing=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16)
    transformer_net_cfg_gpu = TransformerConfig(
@@ -66,7 +65,6 @@ if cfg.transformer_network == 'large':
        max_position_embeddings=128,
        initializer_range=0.02,
        label_smoothing=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16)
if cfg.transformer_network == 'base':
@@ -84,6 +82,5 @@ if cfg.transformer_network == 'base':
        max_position_embeddings=128,
        initializer_range=0.02,
        label_smoothing=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16)

@@ -41,7 +41,6 @@ if cfg.transformer_network == 'large':
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=128,
        label_smoothing=0.1,
        beam_width=4,
        max_decode_length=80,
        length_penalty_weight=1.0,
@@ -61,7 +60,6 @@ if cfg.transformer_network == 'base':
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=128,
        label_smoothing=0.1,
        beam_width=4,
        max_decode_length=80,
        length_penalty_weight=1.0,

@@ -51,8 +51,6 @@ class TransformerConfig:
            model. Default: 128.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        label_smoothing (float): label smoothing setting. Default: 0.1
        beam_width (int): beam width setting. Default: 4
        max_decode_length (int): max decode length in evaluation. Default: 80
        length_penalty_weight (float): normalize scores of translations according to their length. Default: 1.0
@@ -73,7 +71,6 @@ class TransformerConfig:
                 max_position_embeddings=128,
                 initializer_range=0.02,
                 label_smoothing=0.1,
                 beam_width=4,
                 max_decode_length=80,
                 length_penalty_weight=1.0,
@@ -92,7 +89,6 @@ class TransformerConfig:
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.label_smoothing = label_smoothing
        self.beam_width = beam_width
        self.max_decode_length = max_decode_length
        self.length_penalty_weight = length_penalty_weight

@@ -1014,7 +1010,6 @@ class TransformerModel(nn.Cell):
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        self.batch_size = config.batch_size
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers

@@ -52,7 +52,6 @@ def get_config(version='base', batch_size=1):
        max_position_embeddings=128,
        initializer_range=0.02,
        label_smoothing=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16)
    elif version == 'base':
@@ -70,7 +69,6 @@ def get_config(version='base', batch_size=1):
        max_position_embeddings=128,
        initializer_range=0.02,
        label_smoothing=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16)
    else:
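
Taken together, these hunks mean the configuration objects are now constructed without an `input_mask_from_dataset` argument. A minimal sketch of such a construction after this change; the parameter names and values are copied from the hunks above, while the import path and the reliance on defaults for the omitted arguments are assumptions:

```python
import mindspore.common.dtype as mstype
from src.transformer_model import TransformerConfig  # assumed import path

# After this commit, input_mask_from_dataset is no longer accepted or stored.
cfg = TransformerConfig(
    max_position_embeddings=128,
    initializer_range=0.02,
    label_smoothing=0.1,
    beam_width=4,
    max_decode_length=80,
    length_penalty_weight=1.0,
    dtype=mstype.float32,
    compute_type=mstype.float16)
```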