From 812b4b0eabace606b079321caf501c82d6141011 Mon Sep 17 00:00:00 2001
From: shibeiji
Date: Tue, 22 Dec 2020 10:59:54 +0800
Subject: [PATCH] extract embedding table from unified interface

---
 model_zoo/official/nlp/bert/src/bert_model.py | 51 ++++++++-----------
 model_zoo/official/nlp/bert/src/config.py     |  4 +-
 model_zoo/research/cv/centernet/README.md     | 16 +++---
 .../scripts/convert_dataset_to_mindrecord.sh  |  9 ++--
 .../scripts/run_distributed_train_ascend.sh   | 16 ++++--
 .../scripts/run_standalone_train_ascend.sh    |  9 +++-
 .../scripts/run_standalone_train_cpu.sh       | 10 +++-
 7 files changed, 64 insertions(+), 51 deletions(-)

diff --git a/model_zoo/official/nlp/bert/src/bert_model.py b/model_zoo/official/nlp/bert/src/bert_model.py
index 77c3ccc7c37..573f5264413 100644
--- a/model_zoo/official/nlp/bert/src/bert_model.py
+++ b/model_zoo/official/nlp/bert/src/bert_model.py
@@ -166,11 +166,10 @@ class EmbeddingPostprocessor(nn.Cell):
         self.token_type_vocab_size = token_type_vocab_size
         self.use_one_hot_embeddings = use_one_hot_embeddings
         self.max_position_embeddings = max_position_embeddings
-        self.embedding_table = Parameter(initializer
-                                         (TruncatedNormal(initializer_range),
-                                          [token_type_vocab_size,
-                                           embedding_size]))
-
+        self.token_type_embedding = nn.Embedding(
+            vocab_size=token_type_vocab_size,
+            embedding_size=embedding_size,
+            use_one_hot=use_one_hot_embeddings)
         self.shape_flat = (-1,)
         self.one_hot = P.OneHot()
         self.on_value = Tensor(1.0, mstype.float32)
@@ -178,35 +177,28 @@
         self.array_mul = P.MatMul()
         self.reshape = P.Reshape()
         self.shape = tuple(embedding_shape)
-        self.layernorm = nn.LayerNorm((embedding_size,))
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.gather = P.GatherV2()
         self.use_relative_positions = use_relative_positions
         self.slice = P.StridedSlice()
-        self.full_position_embeddings = Parameter(initializer
-                                                  (TruncatedNormal(initializer_range),
-                                                   [max_position_embeddings,
-                                                    embedding_size]))
+        _, seq, _ = self.shape
+        self.full_position_embedding = nn.Embedding(
+            vocab_size=max_position_embeddings,
+            embedding_size=embedding_size,
+            use_one_hot=False)
+        self.layernorm = nn.LayerNorm((embedding_size,))
+        self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
+        self.add = P.TensorAdd()
 
     def construct(self, token_type_ids, word_embeddings):
         """Postprocessors apply positional and token type embeddings to word embeddings."""
         output = word_embeddings
         if self.use_token_type:
-            flat_ids = self.reshape(token_type_ids, self.shape_flat)
-            if self.use_one_hot_embeddings:
-                one_hot_ids = self.one_hot(flat_ids,
-                                           self.token_type_vocab_size, self.on_value, self.off_value)
-                token_type_embeddings = self.array_mul(one_hot_ids,
-                                                       self.embedding_table)
-            else:
-                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
-            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
-            output += token_type_embeddings
+            token_type_embeddings = self.token_type_embedding(token_type_ids)
+            output = self.add(output, token_type_embeddings)
         if not self.use_relative_positions:
-            _, seq, width = self.shape
-            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
-            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
-            output += position_embeddings
+            position_embeddings = self.full_position_embedding(self.position_ids)
+            output = self.add(output, position_embeddings)
         output = self.layernorm(output)
         output = self.dropout(output)
         return output
@@ -771,6 +763,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
     def __init__(self, config):
         super(CreateAttentionMaskFromInputMask, self).__init__()
         self.input_mask = None
 
+        self.cast = P.Cast()
         self.reshape = P.Reshape()
         self.shape = (-1, 1, config.seq_length)
@@ -808,12 +801,11 @@ class BertModel(nn.Cell):
         self.last_idx = self.num_hidden_layers - 1
         output_embedding_shape = [-1, self.seq_length, self.embedding_size]
 
-        self.bert_embedding_lookup = EmbeddingLookup(
+        self.bert_embedding_lookup = nn.Embedding(
             vocab_size=config.vocab_size,
             embedding_size=self.embedding_size,
-            embedding_shape=output_embedding_shape,
-            use_one_hot_embeddings=use_one_hot_embeddings,
-            initializer_range=config.initializer_range)
+            use_one_hot=use_one_hot_embeddings)
+        self.embedding_tables = self.bert_embedding_lookup.embedding_table
 
         self.bert_embedding_postprocessor = EmbeddingPostprocessor(
             embedding_size=self.embedding_size,
@@ -855,7 +847,8 @@ class BertModel(nn.Cell):
     def construct(self, input_ids, token_type_ids, input_mask):
         """Bidirectional Encoder Representations from Transformers."""
         # embedding
-        word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
+        embedding_tables = self.embedding_tables
+        word_embeddings = self.bert_embedding_lookup(input_ids)
         embedding_output = self.bert_embedding_postprocessor(token_type_ids,
                                                              word_embeddings)
 
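
For readers of this diff, the following is a minimal, self-contained sketch (not part of the patch) of the pattern the change relies on: `mindspore.nn.Embedding` keeps its weight in an `embedding_table` Parameter, so the table can be exposed as an attribute of the model instead of being returned from the lookup cell's `construct()`. The vocabulary and hidden sizes below are illustrative and are not taken from config.py.

```python
# Sketch: nn.Embedding owns its weight as the `embedding_table` Parameter.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

vocab_size, hidden_size = 21128, 768          # illustrative values only
embedding = nn.Embedding(vocab_size=vocab_size,
                         embedding_size=hidden_size,
                         use_one_hot=False)

# The lookup call now returns only the embedded ids ...
input_ids = Tensor(np.array([[1, 2, 3, 4]], dtype=np.int32))
word_embeddings = embedding(input_ids)        # shape (1, 4, hidden_size)

# ... while the table itself is available as a Parameter attribute, e.g. for
# weight tying in an output head or for checkpoint export.
embedding_table = embedding.embedding_table   # Parameter of shape (vocab_size, hidden_size)
print(word_embeddings.shape, embedding_table.shape)
```

This is exactly the attribute the patch reads via `self.bert_embedding_lookup.embedding_table`, which is why `BertModel.construct` no longer needs a tuple return from the lookup cell.
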
diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py
index e5e8c8b49fc..d27e29522cd 100644
--- a/model_zoo/official/nlp/bert/src/config.py
+++ b/model_zoo/official/nlp/bert/src/config.py
@@ -38,11 +38,11 @@ cfg = edict({
     'Lamb': edict({
         'learning_rate': 3e-5,
         'end_learning_rate': 0.0,
-        'power': 10.0,
+        'power': 5.0,
         'warmup_steps': 10000,
         'weight_decay': 0.01,
         'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
-        'eps': 1e-6,
+        'eps': 1e-8,
     }),
     'Momentum': edict({
         'learning_rate': 2e-5,
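
The Lamb changes above tune the polynomial learning-rate decay (power 10.0 to 5.0, so the learning rate falls less sharply toward `end_learning_rate`) and tighten the optimizer's numerical epsilon (1e-6 to 1e-8). The snippet below is a hedged sketch of where these two values typically land; the BERT scripts wire them through their own learning-rate cell rather than exactly this way, and the network and step count here are placeholders.

```python
# Sketch only: illustrates how cfg.Lamb.power and cfg.Lamb.eps are consumed.
import mindspore.nn as nn

net = nn.Dense(4, 2)                      # stand-in network, illustrative only
decay_steps = 100000                      # illustrative total step count

lr_schedule = nn.PolynomialDecayLR(learning_rate=3e-5,
                                   end_learning_rate=0.0,
                                   decay_steps=decay_steps,
                                   power=5.0)        # 'power' from the config above

optimizer = nn.Lamb(net.trainable_params(),
                    learning_rate=lr_schedule,
                    eps=1e-8,                        # 'eps' from the config above
                    weight_decay=0.01)
```
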
diff --git a/model_zoo/research/cv/centernet/README.md b/model_zoo/research/cv/centernet/README.md
index e242912e6c5..2c76af2e881 100644
--- a/model_zoo/research/cv/centernet/README.md
+++ b/model_zoo/research/cv/centernet/README.md
@@ -124,16 +124,16 @@ Note: 1.the first run of training will generate the mindrecord file, which will
 
 ```shell
 # create dataset in mindrecord format
-bash scripts/convert_dataset_to_mindrecord.sh
+bash scripts/convert_dataset_to_mindrecord.sh [COCO_DATASET_DIR] [MINDRECORD_DATASET_DIR]
 
 # standalone training on Ascend
-bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
+bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)
 
 # standalone training on CPU
-bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH]
+bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH](optional)
 
 # distributed training on Ascend
-bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH] [RANK_TABLE_FILE]
+bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE] [LOAD_CHECKPOINT_PATH](optional)
 
 # eval on Ascend
 bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH]
@@ -354,7 +354,7 @@ Parameters for optimizer and learning rate:
 Before your first training, convert coco type dataset to mindrecord files is needed to improve performance on host.
 
 ```bash
-bash scripts/convert_dataset_to_mindrecord.sh
+bash scripts/convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir
 ```
 
 The command above will run in the background, after converting mindrecord files will be located in path specified by yourself.
@@ -364,7 +364,7 @@ The command above will run in the background, after converting mindrecord files
 #### Running on Ascend
 
 ```bash
-bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt
+bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt(optional)
 ```
 
 The command above will run in the background, you can view training logs in training_log.txt. After training finished, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows:
@@ -380,7 +380,7 @@ epoch: 349.0, current epoch percent: 1.00, step: 87500, outputs are (Tensor(shap
 #### Running on CPU
 
 ```bash
-bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt
+bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt(optional)
 ```
 
 The command above will run in the background, you can view training logs in training_log.txt. After training finished, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows (rusume from pretrained checkpoint and batch_size was set to be 8):
@@ -401,7 +401,7 @@ epoch: 0.0, current epoch percent: 0.00, step: 5, time of per steps: 45.213 s, o
 #### Running on Ascend
 
 ```bash
-bash scripts/run_distributed_pretrain_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json
+bash scripts/run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt(optional)
 ```
 
 The command above will run in the background, you can view training logs in LOG*/training_log.txt and LOG*/ms_log/. After training finished, you will get some checkpoint files under the LOG*/ckpt_0 folder by default. The loss value will be displayed as follows:
diff --git a/model_zoo/research/cv/centernet/scripts/convert_dataset_to_mindrecord.sh b/model_zoo/research/cv/centernet/scripts/convert_dataset_to_mindrecord.sh
index 1f34b827f25..0258300e508 100644
--- a/model_zoo/research/cv/centernet/scripts/convert_dataset_to_mindrecord.sh
+++ b/model_zoo/research/cv/centernet/scripts/convert_dataset_to_mindrecord.sh
@@ -16,13 +16,16 @@
 
 echo "=============================================================================================================="
 echo "Please run the scipt as: "
-echo "bash convert_dataset_to_mindrecord.sh"
+echo "bash convert_dataset_to_mindrecord.sh /path/coco_dataset_dir /path/mindrecord_dataset_dir"
 echo "=============================================================================================================="
 
+COCO_DIR=$1
+MINDRECORD_DIR=$2
+
 export GLOG_v=1
 PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
 python ${PROJECT_DIR}/../src/dataset.py \
-    --coco_data_dir="" \
-    --mindrecord_dir="" \
+    --coco_data_dir=$COCO_DIR \
+    --mindrecord_dir=$MINDRECORD_DIR \
     --mindrecord_prefix="coco_hp.train.mind" > create_dataset.log 2>&1 &
\ No newline at end of file
diff --git a/model_zoo/research/cv/centernet/scripts/run_distributed_train_ascend.sh b/model_zoo/research/cv/centernet/scripts/run_distributed_train_ascend.sh
index 08115619db1..06f324178b1 100644
--- a/model_zoo/research/cv/centernet/scripts/run_distributed_train_ascend.sh
+++ b/model_zoo/research/cv/centernet/scripts/run_distributed_train_ascend.sh
@@ -16,17 +16,23 @@
 
 echo "================================================================================================================"
 echo "Please run the script as: "
-echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH RANK_TABLE_FILE"
-echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json"
-echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset \"\" /path/hccl.json"
+echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR RANK_TABLE_FILE LOAD_CHECKPOINT_PATH"
+echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json /path/load_ckpt"
+echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/hccl.json"
 echo "It is better to use the absolute path."
echo "For hyper parameter, please note that you should customize the scripts: '{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' " echo "================================================================================================================" CUR_DIR=`pwd` MINDRECORD_DIR=$1 -LOAD_CHECKPOINT_PATH=$2 -HCCL_RANK_FILE=$3 +HCCL_RANK_FILE=$2 +if [ $# == 3 ]; +then + LOAD_CHECKPOINT_PATH=$3 +else + LOAD_CHECKPOINT_PATH="" +fi + python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \ --run_script_dir=${CUR_DIR}/train.py \ diff --git a/model_zoo/research/cv/centernet/scripts/run_standalone_train_ascend.sh b/model_zoo/research/cv/centernet/scripts/run_standalone_train_ascend.sh index b97a8cf2550..51c234a98e7 100644 --- a/model_zoo/research/cv/centernet/scripts/run_standalone_train_ascend.sh +++ b/model_zoo/research/cv/centernet/scripts/run_standalone_train_ascend.sh @@ -18,12 +18,17 @@ echo "========================================================================== echo "Please run the scipt as: " echo "bash run_standalone_train_ascend.sh DEVICE_ID MINDRECORD_DIR LOAD_CHECKPOINT_PATH" echo "for example: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset /path/load_ckpt" -echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset \"\" " +echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset" echo "==============================================================================================================" DEVICE_ID=$1 MINDRECORD_DIR=$2 -LOAD_CHECKPOINT_PATH=$3 +if [ $# == 3 ]; +then + LOAD_CHECKPOINT_PATH=$3 +else + LOAD_CHECKPOINT_PATH="" +fi mkdir -p ms_log PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) diff --git a/model_zoo/research/cv/centernet/scripts/run_standalone_train_cpu.sh b/model_zoo/research/cv/centernet/scripts/run_standalone_train_cpu.sh index d9117f38e14..e8e03ee672b 100644 --- a/model_zoo/research/cv/centernet/scripts/run_standalone_train_cpu.sh +++ b/model_zoo/research/cv/centernet/scripts/run_standalone_train_cpu.sh @@ -18,11 +18,17 @@ echo "========================================================================== echo "Please run the scipt as: " echo "bash run_standalone_train_cpu.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH" echo "for example: bash run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt" -echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset \"\" " +echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset" echo "==============================================================================================================" MINDRECORD_DIR=$1 -LOAD_CHECKPOINT_PATH=$2 +if [ $# == 2 ]; +then + LOAD_CHECKPOINT_PATH=$2 + echo +else + LOAD_CHECKPOINT_PATH="" +fi mkdir -p ms_log PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)