!10318 extract bert embeddings by unified nn.Embedding interface

From: @shibeiji
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2020-12-31 15:36:21 +08:00 committed by Gitee
commit a4627c2074
7 changed files with 64 additions and 51 deletions
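For context, this is the interface the commit standardizes on: each embedding lookup that previously kept its own Parameter table and chose between a OneHot+MatMul path and a GatherV2 path now delegates to nn.Embedding, which owns the table and selects the one-hot path via use_one_hot. A minimal sketch of that usage (the sizes and ids below are illustrative, not taken from the diff):

```python
import numpy as np
from mindspore import Tensor, nn

# Illustrative sizes only; the real values come from the BERT config.
vocab_size, embedding_size = 21128, 768

# Unified interface used throughout this commit: the cell owns its
# embedding_table Parameter and performs the lookup internally.
word_embedding = nn.Embedding(vocab_size=vocab_size,
                              embedding_size=embedding_size,
                              use_one_hot=False)

input_ids = Tensor(np.array([[101, 2769, 102]], dtype=np.int32))
word_embeddings = word_embedding(input_ids)    # shape (1, 3, 768)
table = word_embedding.embedding_table         # Parameter of shape [vocab_size, embedding_size]
```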

View File

@@ -166,11 +166,10 @@ class EmbeddingPostprocessor(nn.Cell):
self.token_type_vocab_size = token_type_vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.max_position_embeddings = max_position_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[token_type_vocab_size,
embedding_size]))
self.token_type_embedding = nn.Embedding(
vocab_size=token_type_vocab_size,
embedding_size=embedding_size,
use_one_hot=use_one_hot_embeddings)
self.shape_flat = (-1,)
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
@@ -178,35 +177,28 @@ class EmbeddingPostprocessor(nn.Cell):
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
self.layernorm = nn.LayerNorm((embedding_size,))
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.GatherV2()
self.use_relative_positions = use_relative_positions
self.slice = P.StridedSlice()
self.full_position_embeddings = Parameter(initializer
(TruncatedNormal(initializer_range),
[max_position_embeddings,
embedding_size]))
_, seq, _ = self.shape
self.full_position_embedding = nn.Embedding(
vocab_size=max_position_embeddings,
embedding_size=embedding_size,
use_one_hot=False)
self.layernorm = nn.LayerNorm((embedding_size,))
self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
self.add = P.TensorAdd()
def construct(self, token_type_ids, word_embeddings):
"""Postprocessors apply positional and token type embeddings to word embeddings."""
output = word_embeddings
if self.use_token_type:
flat_ids = self.reshape(token_type_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids,
self.token_type_vocab_size, self.on_value, self.off_value)
token_type_embeddings = self.array_mul(one_hot_ids,
self.embedding_table)
else:
token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
output += token_type_embeddings
token_type_embeddings = self.token_type_embedding(token_type_ids)
output = self.add(output, token_type_embeddings)
if not self.use_relative_positions:
_, seq, width = self.shape
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
output += position_embeddings
position_embeddings = self.full_position_embedding(self.position_ids)
output = self.add(output, position_embeddings)
output = self.layernorm(output)
output = self.dropout(output)
return output
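A quick shape check on the new position-embedding path (batch, seq, and width below are made up for illustration): position_ids has shape (1, seq), so full_position_embedding(position_ids) comes out as (1, seq, width), and TensorAdd broadcasts it across the batch dimension, which is why the old StridedSlice over a raw Parameter plus reshape is no longer needed.

```python
import numpy as np
from mindspore import Tensor, nn
from mindspore.ops import operations as P

batch, seq, width = 2, 4, 8                                                # illustrative values only
position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))   # (1, seq)

full_position_embedding = nn.Embedding(vocab_size=512,    # stands in for max_position_embeddings
                                       embedding_size=width,
                                       use_one_hot=False)

output = Tensor(np.zeros((batch, seq, width), np.float32))
position_embeddings = full_position_embedding(position_ids)   # (1, seq, width)
output = P.TensorAdd()(output, position_embeddings)           # broadcasts to (batch, seq, width)
```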
@@ -771,6 +763,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
def __init__(self, config):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.input_mask = None
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = (-1, 1, config.seq_length)
@@ -808,12 +801,11 @@ class BertModel(nn.Cell):
self.last_idx = self.num_hidden_layers - 1
output_embedding_shape = [-1, self.seq_length, self.embedding_size]
self.bert_embedding_lookup = EmbeddingLookup(
self.bert_embedding_lookup = nn.Embedding(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range)
use_one_hot=use_one_hot_embeddings)
self.embedding_tables = self.bert_embedding_lookup.embedding_table
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
@@ -855,7 +847,8 @@ class BertModel(nn.Cell):
def construct(self, input_ids, token_type_ids, input_mask):
"""Bidirectional Encoder Representations from Transformers."""
# embedding
word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
embedding_tables = self.embedding_tables
word_embeddings = self.bert_embedding_lookup(input_ids)
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
word_embeddings)
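One behavioural note on this hunk: nn.Embedding returns only the embeddings, so the table that the old EmbeddingLookup used to return alongside them is now read from the cell's embedding_table parameter (cached as self.embedding_tables above). A hedged sketch of the call-site difference, with model standing in for a constructed BertModel:

```python
# Before: the custom EmbeddingLookup returned the table together with the embeddings.
# word_embeddings, embedding_tables = model.bert_embedding_lookup(input_ids)

# After: nn.Embedding returns only the embeddings; the table is a Parameter on the cell.
word_embeddings = model.bert_embedding_lookup(input_ids)
embedding_tables = model.bert_embedding_lookup.embedding_table
```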

View File

@@ -38,11 +38,11 @@ cfg = edict({
'Lamb': edict({
'learning_rate': 3e-5,
'end_learning_rate': 0.0,
'power': 10.0,
'power': 5.0,
'warmup_steps': 10000,
'weight_decay': 0.01,
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
'eps': 1e-6,
'eps': 1e-8,
}),
'Momentum': edict({
'learning_rate': 2e-5,
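For reference, a minimal sketch of how Lamb settings like these are typically consumed when the optimizer is built. The grouping pattern, the placeholder nn.Dense network, and the constant learning rate are assumptions for illustration; in the training script the learning rate would be a warmup-plus-polynomial-decay schedule built from learning_rate, end_learning_rate, power, and warmup_steps above.

```python
from easydict import edict
from mindspore import nn

# Placeholder network so the sketch runs; the real network comes from this repo.
net = nn.Dense(4, 2)

lamb_cfg = edict({'learning_rate': 3e-5, 'weight_decay': 0.01, 'eps': 1e-8,
                  'decay_filter': lambda x: 'layernorm' not in x.name.lower()
                                            and 'bias' not in x.name.lower()})

params = net.trainable_params()
decay_params = list(filter(lamb_cfg.decay_filter, params))
other_params = list(filter(lambda p: not lamb_cfg.decay_filter(p), params))

# Weight decay is applied only to the parameters selected by decay_filter.
group_params = [{'params': decay_params, 'weight_decay': lamb_cfg.weight_decay},
                {'params': other_params, 'weight_decay': 0.0}]

optimizer = nn.Lamb(group_params, learning_rate=lamb_cfg.learning_rate, eps=lamb_cfg.eps)
```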

View File

@@ -124,16 +124,16 @@ Note: 1.the first run of training will generate the mindrecord file, which will
```shell
# create dataset in mindrecord format
bash scripts/convert_dataset_to_mindrecord.sh
bash scripts/convert_dataset_to_mindrecord.sh [COCO_DATASET_DIR] [MINDRECORD_DATASET_DIR]
# standalone training on Ascend
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)
# standalone training on CPU
bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)
# distributed training on Ascend
bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH] [RANK_TABLE_FILE]
bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE] [LOAD_CHECKPOINT_PATH](optional)
# eval on Ascend
bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH]
@@ -354,7 +354,7 @@ Parameters for optimizer and learning rate:
Before your first training, you need to convert the COCO dataset to MindRecord files to improve performance on the host.
```bash
bash scripts/convert_dataset_to_mindrecord.sh
bash scripts/convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir
```
The command above will run in the background; after conversion, the MindRecord files will be located in the path you specified.
@@ -364,7 +364,7 @@ The command above will run in the background, after converting mindrecord files
#### Running on Ascend
```bash
bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt
bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt(optional)
```
The command above will run in the background; you can view training logs in training_log.txt. After training finishes, you will get checkpoint files under the script folder by default. The loss values will be displayed as follows:
@@ -380,7 +380,7 @@ epoch: 349.0, current epoch percent: 1.00, step: 87500, outputs are (Tensor(shap
#### Running on CPU
```bash
bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt
bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt(optional)
```
The command above will run in the background; you can view training logs in training_log.txt. After training finishes, you will get checkpoint files under the script folder by default. The loss values will be displayed as follows (resumed from a pretrained checkpoint with batch_size set to 8):
@@ -401,7 +401,7 @@ epoch: 0.0, current epoch percent: 0.00, step: 5, time of per steps: 45.213 s, o
#### Running on Ascend
```bash
bash scripts/run_distributed_train_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json
bash scripts/run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt(optional)
```
The command above will run in the background; you can view training logs in LOG*/training_log.txt and LOG*/ms_log/. After training finishes, you will get checkpoint files under the LOG*/ckpt_0 folder by default. The loss value will be displayed as follows:

View File

@@ -16,13 +16,16 @@
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash convert_dataset_to_mindrecord.sh"
echo "bash convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir"
echo "=============================================================================================================="
COCO_DIR=$1
MINDRECORD_DIR=$2
export GLOG_v=1
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
python ${PROJECT_DIR}/../src/dataset.py \
--coco_data_dir="" \
--mindrecord_dir="" \
--coco_data_dir=$COCO_DIR \
--mindrecord_dir=$MINDRECORD_DIR \
--mindrecord_prefix="coco_hp.train.mind" > create_dataset.log 2>&1 &

View File

@@ -16,17 +16,23 @@
echo "================================================================================================================"
echo "Please run the script as: "
echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH RANK_TABLE_FILE"
echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json"
echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset \"\" /path/hccl.json"
echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR RANK_TABLE_FILE LOAD_CHECKPOINT_PATH"
echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt"
echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json"
echo "It is better to use the absolute path."
echo "For hyper parameter, please note that you should customize the scripts:
'{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
echo "================================================================================================================"
CUR_DIR=`pwd`
MINDRECORD_DIR=$1
LOAD_CHECKPOINT_PATH=$2
HCCL_RANK_FILE=$3
HCCL_RANK_FILE=$2
if [ $# == 3 ];
then
LOAD_CHECKPOINT_PATH=$3
else
LOAD_CHECKPOINT_PATH=""
fi
python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
--run_script_dir=${CUR_DIR}/train.py \

View File

@@ -18,12 +18,17 @@ echo "==========================================================================
echo "Please run the scipt as: "
echo "bash run_standalone_train_ascend.sh DEVICE_ID MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
echo "for example: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset /path/load_ckpt"
echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset \"\" "
echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset"
echo "=============================================================================================================="
DEVICE_ID=$1
MINDRECORD_DIR=$2
LOAD_CHECKPOINT_PATH=$3
if [ $# == 3 ];
then
LOAD_CHECKPOINT_PATH=$3
else
LOAD_CHECKPOINT_PATH=""
fi
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)

View File

@@ -18,11 +18,17 @@ echo "==========================================================================
echo "Please run the scipt as: "
echo "bash run_standalone_train_cpu.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
echo "for example: bash run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt"
echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset \"\" "
echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset"
echo "=============================================================================================================="
MINDRECORD_DIR=$1
LOAD_CHECKPOINT_PATH=$2
if [ $# == 2 ];
then
LOAD_CHECKPOINT_PATH=$2
echo
else
LOAD_CHECKPOINT_PATH=""
fi
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)