!10318 extract bert embeddings by unified nn.Embedding interface
From: @shibeiji Reviewed-by: Signed-off-by:
commit a4627c2074
@@ -166,11 +166,10 @@ class EmbeddingPostprocessor(nn.Cell):
         self.token_type_vocab_size = token_type_vocab_size
         self.use_one_hot_embeddings = use_one_hot_embeddings
         self.max_position_embeddings = max_position_embeddings
-        self.embedding_table = Parameter(initializer
-                                         (TruncatedNormal(initializer_range),
-                                          [token_type_vocab_size,
-                                           embedding_size]))
-
+        self.token_type_embedding = nn.Embedding(
+            vocab_size=token_type_vocab_size,
+            embedding_size=embedding_size,
+            use_one_hot=use_one_hot_embeddings)
         self.shape_flat = (-1,)
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
@@ -178,35 +177,28 @@ class EmbeddingPostprocessor(nn.Cell):
         self.array_mul = P.MatMul()
         self.reshape = P.Reshape()
         self.shape = tuple(embedding_shape)
-        self.layernorm = nn.LayerNorm((embedding_size,))
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.gather = P.GatherV2()
         self.use_relative_positions = use_relative_positions
         self.slice = P.StridedSlice()
-        self.full_position_embeddings = Parameter(initializer
-                                                  (TruncatedNormal(initializer_range),
-                                                   [max_position_embeddings,
-                                                    embedding_size]))
+        _, seq, _ = self.shape
+        self.full_position_embedding = nn.Embedding(
+            vocab_size=max_position_embeddings,
+            embedding_size=embedding_size,
+            use_one_hot=False)
+        self.layernorm = nn.LayerNorm((embedding_size,))
+        self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
+        self.add = P.TensorAdd()

     def construct(self, token_type_ids, word_embeddings):
         """Postprocessors apply positional and token type embeddings to word embeddings."""
         output = word_embeddings
         if self.use_token_type:
-            flat_ids = self.reshape(token_type_ids, self.shape_flat)
-            if self.use_one_hot_embeddings:
-                one_hot_ids = self.one_hot(flat_ids,
-                                           self.token_type_vocab_size, self.on_value, self.off_value)
-                token_type_embeddings = self.array_mul(one_hot_ids,
-                                                       self.embedding_table)
-            else:
-                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
-            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
-            output += token_type_embeddings
+            token_type_embeddings = self.token_type_embedding(token_type_ids)
+            output = self.add(output, token_type_embeddings)
         if not self.use_relative_positions:
-            _, seq, width = self.shape
-            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
-            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
-            output += position_embeddings
+            position_embeddings = self.full_position_embedding(self.position_ids)
+            output = self.add(output, position_embeddings)
         output = self.layernorm(output)
         output = self.dropout(output)
         return output
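For context, nn.Embedding owns its weight Parameter and performs the id-to-vector lookup itself, so the token-type path (one-hot MatMul or GatherV2) and the position path (StridedSlice over a full table) each collapse to a single call on precomputed ids. A minimal standalone sketch of the two lookups, assuming a MindSpore 1.x environment; the toy sizes and PyNative mode are illustrative and not part of the commit:

```python
import numpy as np
from mindspore import Tensor, nn, context

context.set_context(mode=context.PYNATIVE_MODE)

embedding_size, max_position_embeddings, seq = 8, 32, 6

# Position path: ids are fixed at construction time, one lookup per step.
full_position_embedding = nn.Embedding(vocab_size=max_position_embeddings,
                                       embedding_size=embedding_size,
                                       use_one_hot=False)
position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
position_embeddings = full_position_embedding(position_ids)   # shape (1, seq, embedding_size)

# Token-type path: same cell type, optionally one-hot under the hood.
token_type_embedding = nn.Embedding(vocab_size=2,
                                    embedding_size=embedding_size,
                                    use_one_hot=True)
token_type_ids = Tensor(np.zeros((1, seq), dtype=np.int32))
token_type_embeddings = token_type_embedding(token_type_ids)  # shape (1, seq, embedding_size)
print(position_embeddings.shape, token_type_embeddings.shape)
```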
@@ -771,6 +763,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
     def __init__(self, config):
         super(CreateAttentionMaskFromInputMask, self).__init__()
         self.input_mask = None

         self.cast = P.Cast()
         self.reshape = P.Reshape()
         self.shape = (-1, 1, config.seq_length)
@@ -808,12 +801,11 @@ class BertModel(nn.Cell):
         self.last_idx = self.num_hidden_layers - 1
         output_embedding_shape = [-1, self.seq_length, self.embedding_size]

-        self.bert_embedding_lookup = EmbeddingLookup(
+        self.bert_embedding_lookup = nn.Embedding(
             vocab_size=config.vocab_size,
             embedding_size=self.embedding_size,
-            embedding_shape=output_embedding_shape,
-            use_one_hot_embeddings=use_one_hot_embeddings,
-            initializer_range=config.initializer_range)
+            use_one_hot=use_one_hot_embeddings)
+        self.embedding_tables = self.bert_embedding_lookup.embedding_table

         self.bert_embedding_postprocessor = EmbeddingPostprocessor(
             embedding_size=self.embedding_size,
@@ -855,7 +847,8 @@ class BertModel(nn.Cell):
     def construct(self, input_ids, token_type_ids, input_mask):
         """Bidirectional Encoder Representations from Transformers."""
         # embedding
-        word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
+        embedding_tables = self.embedding_tables
+        word_embeddings = self.bert_embedding_lookup(input_ids)
         embedding_output = self.bert_embedding_postprocessor(token_type_ids,
                                                              word_embeddings)

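Because nn.Embedding returns only the looked-up vectors, the weight table that other parts of the model share (for example a tied output projection) is now read once from the cell's embedding_table attribute in __init__ instead of being returned by every call. A minimal sketch of that access pattern, assuming MindSpore 1.x; names and sizes are illustrative:

```python
import numpy as np
from mindspore import Tensor, nn, context

context.set_context(mode=context.PYNATIVE_MODE)

lookup = nn.Embedding(vocab_size=32, embedding_size=8, use_one_hot=False)

# Read the shared weight once, the way BertModel.__init__ now does.
embedding_tables = lookup.embedding_table        # Parameter of shape [32, 8]

input_ids = Tensor(np.zeros((2, 6), dtype=np.int32))
word_embeddings = lookup(input_ids)              # already [batch, seq, embedding_size]
print(word_embeddings.shape)
```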
@@ -38,11 +38,11 @@ cfg = edict({
     'Lamb': edict({
         'learning_rate': 3e-5,
         'end_learning_rate': 0.0,
-        'power': 10.0,
+        'power': 5.0,
         'warmup_steps': 10000,
         'weight_decay': 0.01,
         'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
-        'eps': 1e-6,
+        'eps': 1e-8,
     }),
     'Momentum': edict({
         'learning_rate': 2e-5,
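On the config side, 'power' and 'end_learning_rate' normally parameterize a polynomial-decay learning-rate schedule, and 'eps' is the optimizer's numerical-stability term; the hunk lowers the decay power from 10.0 to 5.0 and tightens eps from 1e-6 to 1e-8. A hypothetical illustration of how such a schedule is usually computed (the actual consumer of cfg.Lamb is outside this diff):

```python
# Hypothetical helper, not from the source: the usual polynomial-decay formula
# that fields like 'power' and 'end_learning_rate' feed.
def polynomial_decay_lr(step, total_steps, learning_rate=3e-5,
                        end_learning_rate=0.0, power=5.0):
    decayed = (1 - min(step, total_steps) / total_steps) ** power
    return (learning_rate - end_learning_rate) * decayed + end_learning_rate

# A smaller power (5.0 vs 10.0) keeps the learning rate higher for longer
# before it approaches end_learning_rate.
print(polynomial_decay_lr(step=50_000, total_steps=100_000))
```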
@@ -124,16 +124,16 @@ Note: 1.the first run of training will generate the mindrecord file, which will

 ```shell
 # create dataset in mindrecord format
-bash scripts/convert_dataset_to_mindrecord.sh
+bash scripts/convert_dataset_to_mindrecord.sh [COCO_DATASET_DIR] [MINDRECORD_DATASET_DIR]

 # standalone training on Ascend
-bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
+bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)

 # standalone training on CPU
-bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
+bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)

 # distributed training on Ascend
-bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH] [RANK_TABLE_FILE]
+bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE] [LOAD_CHECKPOINT_PATH](optional)

 # eval on Ascend
 bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH]
@@ -354,7 +354,7 @@ Parameters for optimizer and learning rate:
 Before your first training run, the COCO-format dataset needs to be converted to MindRecord files to improve performance on the host.

 ```bash
-bash scripts/convert_dataset_to_mindrecord.sh
+bash scripts/convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir
 ```

 The command above will run in the background; after conversion, the MindRecord files will be located in the path you specified.
@@ -364,7 +364,7 @@ The command above will run in the background, after converting mindrecord files
 #### Running on Ascend

 ```bash
-bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt
+bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt(optional)
 ```

 The command above will run in the background; you can view the training logs in training_log.txt. After training finishes, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows:
@@ -380,7 +380,7 @@ epoch: 349.0, current epoch percent: 1.00, step: 87500, outputs are (Tensor(shap
 #### Running on CPU

 ```bash
-bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt
+bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt(optional)
 ```

 The command above will run in the background; you can view the training logs in training_log.txt. After training finishes, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows (resumed from a pretrained checkpoint with batch_size set to 8):
@@ -401,7 +401,7 @@ epoch: 0.0, current epoch percent: 0.00, step: 5, time of per steps: 45.213 s, o
 #### Running on Ascend

 ```bash
-bash scripts/run_distributed_pretrain_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json
+bash scripts/run_distributed_pretrain_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt(optional)
 ```

 The command above will run in the background; you can view the training logs in LOG*/training_log.txt and LOG*/ms_log/. After training finishes, you will get some checkpoint files under the LOG*/ckpt_0 folder by default. The loss value will be displayed as follows:
@@ -16,13 +16,16 @@

 echo "=============================================================================================================="
 echo "Please run the scipt as: "
-echo "bash convert_dataset_to_mindrecord.sh"
+echo "bash convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir"
 echo "=============================================================================================================="

+COCO_DIR=$1
+MINDRECORD_DIR=$2
+
 export GLOG_v=1
 PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)

 python ${PROJECT_DIR}/../src/dataset.py \
-    --coco_data_dir="" \
-    --mindrecord_dir="" \
+    --coco_data_dir=$COCO_DIR \
+    --mindrecord_dir=$MINDRECORD_DIR \
     --mindrecord_prefix="coco_hp.train.mind" > create_dataset.log 2>&1 &
@@ -16,17 +16,23 @@

 echo "================================================================================================================"
 echo "Please run the script as: "
-echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH RANK_TABLE_FILE"
-echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json"
-echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset \"\" /path/hccl.json"
+echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR RANK_TABLE_FILE LOAD_CHECKPOINT_PATH"
+echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt"
+echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json"
 echo "It is better to use the absolute path."
 echo "For hyper parameter, please note that you should customize the scripts:
 '{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
 echo "================================================================================================================"
 CUR_DIR=`pwd`
 MINDRECORD_DIR=$1
-LOAD_CHECKPOINT_PATH=$2
-HCCL_RANK_FILE=$3
+HCCL_RANK_FILE=$2
+if [ $# == 3 ];
+then
+    LOAD_CHECKPOINT_PATH=$3
+else
+    LOAD_CHECKPOINT_PATH=""
+fi
+

 python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
     --run_script_dir=${CUR_DIR}/train.py \
@@ -18,12 +18,17 @@ echo "==========================================================================
 echo "Please run the scipt as: "
 echo "bash run_standalone_train_ascend.sh DEVICE_ID MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
 echo "for example: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset /path/load_ckpt"
-echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset \"\" "
+echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset"
 echo "=============================================================================================================="

 DEVICE_ID=$1
 MINDRECORD_DIR=$2
-LOAD_CHECKPOINT_PATH=$3
+if [ $# == 3 ];
+then
+    LOAD_CHECKPOINT_PATH=$3
+else
+    LOAD_CHECKPOINT_PATH=""
+fi

 mkdir -p ms_log
 PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
@@ -18,11 +18,17 @@ echo "==========================================================================
 echo "Please run the scipt as: "
 echo "bash run_standalone_train_cpu.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH"
 echo "for example: bash run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt"
-echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset \"\" "
+echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset"
 echo "=============================================================================================================="

 MINDRECORD_DIR=$1
-LOAD_CHECKPOINT_PATH=$2
+if [ $# == 2 ];
+then
+    LOAD_CHECKPOINT_PATH=$2
+    echo
+else
+    LOAD_CHECKPOINT_PATH=""
+fi

 mkdir -p ms_log
 PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)