diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index b5a0fb3bd8c..dbdc6ea9673 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -308,7 +308,7 @@ def get_bprop_softmax(self): axis = self.axis def bprop(x, out, dout): - dx = mul(sub(dout, sum_func(mul(dout, out), axis)), out) + dx = mul(out, sub(dout, sum_func(mul(out, dout), axis))) return (dx,) return bprop diff --git a/example/bert_clue/README.md b/model_zoo/bert/README.md similarity index 97% rename from example/bert_clue/README.md rename to model_zoo/bert/README.md index 01e0913411b..98a78062ddc 100644 --- a/example/bert_clue/README.md +++ b/model_zoo/bert/README.md @@ -16,12 +16,12 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base]( - Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model. ``` bash - sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR + sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR ``` - Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model. ``` bash - sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH + sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH ``` ### Fine-Tuning diff --git a/example/bert_clue/evaluation.py b/model_zoo/bert/evaluation.py similarity index 96% rename from example/bert_clue/evaluation.py rename to model_zoo/bert/evaluation.py index 2d1086236d1..c58bf836fdf 100644 --- a/example/bert_clue/evaluation.py +++ b/model_zoo/bert/evaluation.py @@ -19,8 +19,6 @@ Bert evaluation script. import os import numpy as np -from evaluation_config import cfg, bert_net_cfg -from utils import BertNER, BertCLS import mindspore.common.dtype as mstype from mindspore import context from mindspore.common.tensor import Tensor @@ -28,9 +26,11 @@ import mindspore.dataset as de import mindspore.dataset.transforms.c_transforms as C from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from CRF import postprocess -from cluener_evaluation import submit -from finetune_config import tag_to_index +from src.evaluation_config import cfg, bert_net_cfg +from src.utils import BertNER, BertCLS +from src.CRF import postprocess +from src.cluener_evaluation import submit +from src.finetune_config import tag_to_index class Accuracy(): ''' diff --git a/example/bert_clue/finetune.py b/model_zoo/bert/finetune.py similarity index 98% rename from example/bert_clue/finetune.py rename to model_zoo/bert/finetune.py index ee62d940b57..6fa08beb73e 100644 --- a/example/bert_clue/finetune.py +++ b/model_zoo/bert/finetune.py @@ -18,8 +18,8 @@ Bert finetune script. 
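Note on the softmax bprop change above: the rewritten line only reorders operands, and the gradient it computes is the standard softmax backward, dx = y * (dout - sum(dout * y, axis)) where y is the softmax output. A minimal NumPy sketch of that identity (illustrative only, not the MindSpore primitives, and assuming the reduction keeps the reduced axis for broadcasting):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.randn(2, 5)
dout = np.random.randn(2, 5)
y = softmax(x)

# dx = y * (dout - sum(dout * y)) -- the same formula as the bprop above
dx = y * (dout - np.sum(dout * y, axis=-1, keepdims=True))

# Check against the explicit per-row Jacobian J = diag(y) - y y^T
dx_ref = np.stack([(np.diag(y[i]) - np.outer(y[i], y[i])) @ dout[i] for i in range(2)])
assert np.allclose(dx, dx_ref)
```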
''' import os -from utils import BertFinetuneCell, BertCLS, BertNER -from finetune_config import cfg, bert_net_cfg, tag_to_index +from src.utils import BertFinetuneCell, BertCLS, BertNER +from src.finetune_config import cfg, bert_net_cfg, tag_to_index import mindspore.common.dtype as mstype import mindspore.communication.management as D from mindspore import context diff --git a/example/bert_clue/run_pretrain.py b/model_zoo/bert/run_pretrain.py similarity index 90% rename from example/bert_clue/run_pretrain.py rename to model_zoo/bert/run_pretrain.py index c587d41bc32..1a267b93ffa 100644 --- a/example/bert_clue/run_pretrain.py +++ b/model_zoo/bert/run_pretrain.py @@ -26,10 +26,10 @@ from mindspore.train.parallel_utils import ParallelMode from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR -from dataset import create_bert_dataset -from config import cfg, bert_net_cfg +from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell +from src.dataset import create_bert_dataset +from src.config import cfg, bert_net_cfg _current_dir = os.path.dirname(os.path.realpath(__file__)) class LossCallBack(Callback): @@ -48,10 +48,8 @@ class LossCallBack(Callback): self._per_print_times = per_print_times def step_end(self, run_context): cb_params = run_context.original_args() - with open("./loss.log", "a+") as f: - f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, - str(cb_params.net_outputs))) - f.write('\n') + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs))) def run_pretrain(): """pre-train bert_clue""" @@ -81,6 +79,11 @@ def run_pretrain(): context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, device_num=device_num) + from mindspore.parallel._auto_parallel_context import auto_parallel_context + if bert_net_cfg.num_hidden_layers == 12: + auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) + elif bert_net_cfg.num_hidden_layers == 24: + auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) D.init() rank = args_opt.device_id % device_num else: diff --git a/example/bert_clue/run_distribute_pretrain.sh b/model_zoo/bert/scripts/run_distribute_pretrain.sh similarity index 86% rename from example/bert_clue/run_distribute_pretrain.sh rename to model_zoo/bert/scripts/run_distribute_pretrain.sh index 58ae389a0ec..1d77ff81190 100644 --- a/example/bert_clue/run_distribute_pretrain.sh +++ b/model_zoo/bert/scripts/run_distribute_pretrain.sh @@ -16,8 +16,8 @@ echo "==============================================================================================================" echo "Please run the scipt as: " -echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH" -echo "for example: sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json" +echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR 
MINDSPORE_HCCL_CONFIG_PATH" +echo "for example: bash run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json" echo "It is better to use absolute path." echo "==============================================================================================================" @@ -49,6 +49,10 @@ do cp *.py ./LOG$i cd ./LOG$i || exit echo "start training for rank $i, device $DEVICE_ID" + mkdir -p ms_log + CUR_DIR=`pwd` + export GLOG_log_dir=${CUR_DIR}/ms_log + export GLOG_logtostderr=0 env > env.log taskset -c $cmdopt python ../run_pretrain.py \ --distribute="true" \ @@ -59,7 +63,7 @@ do --enable_lossscale="true" \ --do_shuffle="true" \ --enable_data_sink="true" \ - --data_sink_steps=1 \ + --data_sink_steps=100 \ --checkpoint_path="" \ --save_checkpoint_steps=10000 \ --save_checkpoint_num=1 \ diff --git a/example/bert_clue/run_standalone_pretrain.sh b/model_zoo/bert/scripts/run_standalone_pretrain.sh similarity index 82% rename from example/bert_clue/run_standalone_pretrain.sh rename to model_zoo/bert/scripts/run_standalone_pretrain.sh index 7795a4e46df..438dda58c35 100644 --- a/example/bert_clue/run_standalone_pretrain.sh +++ b/model_zoo/bert/scripts/run_standalone_pretrain.sh @@ -16,8 +16,8 @@ echo "==============================================================================================================" echo "Please run the scipt as: " -echo "sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR" -echo "for example: sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json" +echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR" +echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json" echo "==============================================================================================================" DEVICE_ID=$1 @@ -25,6 +25,10 @@ EPOCH_SIZE=$2 DATA_DIR=$3 SCHEMA_DIR=$4 +mkdir -p ms_log +CUR_DIR=`pwd` +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 python run_pretrain.py \ --distribute="false" \ --epoch_size=$EPOCH_SIZE \ @@ -33,7 +37,7 @@ python run_pretrain.py \ --enable_lossscale="true" \ --do_shuffle="true" \ --enable_data_sink="true" \ - --data_sink_steps=1 \ + --data_sink_steps=100 \ --checkpoint_path="" \ --save_checkpoint_steps=10000 \ --save_checkpoint_num=1 \ diff --git a/example/bert_clue/CRF.py b/model_zoo/bert/src/CRF.py similarity index 100% rename from example/bert_clue/CRF.py rename to model_zoo/bert/src/CRF.py diff --git a/mindspore/model_zoo/Bert_NEZHA/__init__.py b/model_zoo/bert/src/__init__.py similarity index 100% rename from mindspore/model_zoo/Bert_NEZHA/__init__.py rename to model_zoo/bert/src/__init__.py diff --git a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py b/model_zoo/bert/src/bert_for_pre_training.py similarity index 99% rename from mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py rename to model_zoo/bert/src/bert_for_pre_training.py index 30a1c5290c5..4732cc795fd 100644 --- a/mindspore/model_zoo/Bert_NEZHA/bert_for_pre_training.py +++ b/model_zoo/bert/src/bert_for_pre_training.py @@ -357,10 +357,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: self.reducer_flag = True self.grad_reducer = F.identity + self.degree = 1 if self.reducer_flag: - mean = context.get_auto_parallel_context("mirror_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.degree = 
get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() @@ -411,10 +411,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): masked_lm_weights, self.cast(scaling_sens, mstype.float32)) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) # apply grad reducer on grads grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) self.get_status(init) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: diff --git a/mindspore/model_zoo/Bert_NEZHA/bert_model.py b/model_zoo/bert/src/bert_model.py similarity index 95% rename from mindspore/model_zoo/Bert_NEZHA/bert_model.py rename to model_zoo/bert/src/bert_model.py index 899e8f47122..310d330daaa 100644 --- a/mindspore/model_zoo/Bert_NEZHA/bert_model.py +++ b/model_zoo/bert/src/bert_model.py @@ -25,6 +25,7 @@ from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter +from .fused_layer_norm import FusedLayerNorm class BertConfig: @@ -77,7 +78,8 @@ class BertConfig: input_mask_from_dataset=True, token_type_ids_from_dataset=True, dtype=mstype.float32, - compute_type=mstype.float32): + compute_type=mstype.float32, + enable_fused_layernorm=False): self.batch_size = batch_size self.seq_length = seq_length self.vocab_size = vocab_size @@ -96,6 +98,7 @@ class BertConfig: self.use_relative_positions = use_relative_positions self.dtype = dtype self.compute_type = compute_type + self.enable_fused_layernorm = enable_fused_layernorm class EmbeddingLookup(nn.Cell): @@ -240,13 +243,19 @@ class BertOutput(nn.Cell): out_channels, initializer_range=0.02, dropout_prob=0.1, - compute_type=mstype.float32): + compute_type=mstype.float32, + enable_fused_layernorm=False): super(BertOutput, self).__init__() self.dense = nn.Dense(in_channels, out_channels, weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) self.dropout = nn.Dropout(1 - dropout_prob) + self.dropout_prob = dropout_prob self.add = P.TensorAdd() - self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + if compute_type == mstype.float16: + self.layernorm = FusedLayerNorm((out_channels,), + use_batch_norm=enable_fused_layernorm).to_float(compute_type) + else: + self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) self.cast = P.Cast() def construct(self, hidden_status, input_tensor): @@ -481,12 +490,13 @@ class BertAttention(nn.Cell): self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) self.cast_compute_type = SaturateCast(dst_type=compute_type) - self._generate_relative_positions_embeddings = \ - RelaPosEmbeddingsGenerator(length=to_seq_length, - depth=size_per_head, - max_relative_position=16, - initializer_range=initializer_range, - use_one_hot_embeddings=use_one_hot_embeddings) + if self.use_relative_positions: + self._generate_relative_positions_embeddings = \ + RelaPosEmbeddingsGenerator(length=to_seq_length, + depth=size_per_head, + max_relative_position=16, + initializer_range=initializer_range, + use_one_hot_embeddings=use_one_hot_embeddings) def 
construct(self, from_tensor, to_tensor, attention_mask): # reshape 2d/3d input tensors to 2d @@ -529,7 +539,7 @@ class BertAttention(nn.Cell): self.trans_shape_position) attention_scores = attention_scores + key_position_scores_r_t - attention_scores = self.multiply(attention_scores, self.scores_mul) + attention_scores = self.multiply(self.scores_mul, attention_scores) if self.has_attention_mask: attention_mask = self.expand_dims(attention_mask, 1) @@ -606,7 +616,8 @@ class BertSelfAttention(nn.Cell): initializer_range=0.02, hidden_dropout_prob=0.1, use_relative_positions=False, - compute_type=mstype.float32): + compute_type=mstype.float32, + enable_fused_layernorm=False): super(BertSelfAttention, self).__init__() if hidden_size % num_attention_heads != 0: raise ValueError("The hidden size (%d) is not a multiple of the number " @@ -634,7 +645,8 @@ class BertSelfAttention(nn.Cell): out_channels=hidden_size, initializer_range=initializer_range, dropout_prob=hidden_dropout_prob, - compute_type=compute_type) + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) self.reshape = P.Reshape() self.shape = (-1, hidden_size) @@ -676,7 +688,8 @@ class BertEncoderCell(nn.Cell): hidden_dropout_prob=0.1, use_relative_positions=False, hidden_act="gelu", - compute_type=mstype.float32): + compute_type=mstype.float32, + enable_fused_layernorm=False): super(BertEncoderCell, self).__init__() self.attention = BertSelfAttention( batch_size=batch_size, @@ -688,7 +701,8 @@ class BertEncoderCell(nn.Cell): initializer_range=initializer_range, hidden_dropout_prob=hidden_dropout_prob, use_relative_positions=use_relative_positions, - compute_type=compute_type) + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) self.intermediate = nn.Dense(in_channels=hidden_size, out_channels=intermediate_size, activation=hidden_act, @@ -697,7 +711,8 @@ class BertEncoderCell(nn.Cell): out_channels=hidden_size, initializer_range=initializer_range, dropout_prob=hidden_dropout_prob, - compute_type=compute_type) + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) def construct(self, hidden_states, attention_mask): # self-attention @@ -744,7 +759,8 @@ class BertTransformer(nn.Cell): use_relative_positions=False, hidden_act="gelu", compute_type=mstype.float32, - return_all_encoders=False): + return_all_encoders=False, + enable_fused_layernorm=False): super(BertTransformer, self).__init__() self.return_all_encoders = return_all_encoders @@ -761,7 +777,8 @@ class BertTransformer(nn.Cell): hidden_dropout_prob=hidden_dropout_prob, use_relative_positions=use_relative_positions, hidden_act=hidden_act, - compute_type=compute_type) + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) layers.append(layer) self.layers = nn.CellList(layers) @@ -888,7 +905,8 @@ class BertModel(nn.Cell): use_relative_positions=config.use_relative_positions, hidden_act=config.hidden_act, compute_type=config.compute_type, - return_all_encoders=True) + return_all_encoders=True, + enable_fused_layernorm=config.enable_fused_layernorm) self.cast = P.Cast() self.dtype = config.dtype diff --git a/example/bert_clue/cluener_evaluation.py b/model_zoo/bert/src/cluener_evaluation.py similarity index 97% rename from example/bert_clue/cluener_evaluation.py rename to model_zoo/bert/src/cluener_evaluation.py index 4f1c98177ba..c2c6770a4ad 100644 --- a/example/bert_clue/cluener_evaluation.py +++ b/model_zoo/bert/src/cluener_evaluation.py @@ -17,12 +17,12 @@ import json import numpy as np 
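For the BertTrainOneStepWithLossScaleCell change shown earlier, the gradient reducer is now built with mean disabled, so the AllReduce sums gradients across `degree` devices; dividing afterwards by `scaling_sens * degree` undoes the loss scaling and the sum in one step. A rough NumPy sketch of that arithmetic (illustrative only, not the actual distributed code):

```python
import numpy as np

degree, loss_scale = 8, 65536.0
true_grads = [np.random.randn(4) for _ in range(degree)]     # per-device gradients

scaled = [g * loss_scale for g in true_grads]                 # backprop through the scaled loss
reduced = np.sum(scaled, axis=0)                              # sum-AllReduce (mean disabled)
recovered = reduced / (loss_scale * degree)                   # grad_scale with scaling_sens * self.degree

assert np.allclose(recovered, np.mean(true_grads, axis=0))    # mean of the unscaled gradients
```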
-from evaluation_config import cfg import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor -from CRF import postprocess import tokenization from sample_process import label_generation, process_one_example_p +from .evaluation_config import cfg +from .CRF import postprocess vocab_file = "./vocab.txt" tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file) diff --git a/example/bert_clue/config.py b/model_zoo/bert/src/config.py similarity index 76% rename from example/bert_clue/config.py rename to model_zoo/bert/src/config.py index 7cdfcc14f67..d1062b78eec 100644 --- a/example/bert_clue/config.py +++ b/model_zoo/bert/src/config.py @@ -17,16 +17,16 @@ network config setting, will be used in dataset.py, run_pretrain.py """ from easydict import EasyDict as edict import mindspore.common.dtype as mstype -from mindspore.model_zoo.Bert_NEZHA import BertConfig +from .bert_model import BertConfig cfg = edict({ 'bert_network': 'base', - 'loss_scale_value': 2**32, + 'loss_scale_value': 65536, 'scale_factor': 2, 'scale_window': 1000, 'optimizer': 'Lamb', 'AdamWeightDecayDynamicLR': edict({ 'learning_rate': 3e-5, - 'end_learning_rate': 1e-7, + 'end_learning_rate': 1e-10, 'power': 5.0, 'weight_decay': 1e-5, 'eps': 1e-6, @@ -34,7 +34,7 @@ cfg = edict({ }), 'Lamb': edict({ 'start_learning_rate': 3e-5, - 'end_learning_rate': 1e-7, + 'end_learning_rate': 1e-10, 'power': 10.0, 'warmup_steps': 10000, 'weight_decay': 0.01, @@ -56,7 +56,7 @@ if cfg.bert_network == 'base': bert_net_cfg = BertConfig( batch_size=32, seq_length=128, - vocab_size=21128, + vocab_size=21136, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, @@ -71,13 +71,13 @@ if cfg.bert_network == 'base': input_mask_from_dataset=True, token_type_ids_from_dataset=True, dtype=mstype.float32, - compute_type=mstype.float16, + compute_type=mstype.float16 ) if cfg.bert_network == 'nezha': bert_net_cfg = BertConfig( batch_size=32, seq_length=128, - vocab_size=21128, + vocab_size=21136, hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, @@ -92,5 +92,27 @@ if cfg.bert_network == 'nezha': input_mask_from_dataset=True, token_type_ids_from_dataset=True, dtype=mstype.float32, - compute_type=mstype.float16, + compute_type=mstype.float16 + ) +if cfg.bert_network == 'large': + bert_net_cfg = BertConfig( + batch_size=16, + seq_length=512, + vocab_size=30528, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=True ) diff --git a/example/bert_clue/dataset.py b/model_zoo/bert/src/dataset.py similarity index 92% rename from example/bert_clue/dataset.py rename to model_zoo/bert/src/dataset.py index f930b67330d..1828fac4544 100644 --- a/example/bert_clue/dataset.py +++ b/model_zoo/bert/src/dataset.py @@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype import mindspore.dataset.engine.datasets as de import mindspore.dataset.transforms.c_transforms as C from mindspore import log as logger -from config import bert_net_cfg +from .config import bert_net_cfg def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true", @@ -31,8 +31,9 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, 
do_shuffle="true", e files = os.listdir(data_dir) data_files = [] for file_name in files: - data_files.append(os.path.join(data_dir, file_name)) - ds = de.TFRecordDataset(data_files, schema_dir, + if "tfrecord" in file_name: + data_files.append(os.path.join(data_dir, file_name)) + ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, diff --git a/example/bert_clue/evaluation_config.py b/model_zoo/bert/src/evaluation_config.py similarity index 96% rename from example/bert_clue/evaluation_config.py rename to model_zoo/bert/src/evaluation_config.py index ceaaf899692..b18c5643b00 100644 --- a/example/bert_clue/evaluation_config.py +++ b/model_zoo/bert/src/evaluation_config.py @@ -19,7 +19,7 @@ config settings, will be used in finetune.py from easydict import EasyDict as edict import mindspore.common.dtype as mstype -from mindspore.model_zoo.Bert_NEZHA import BertConfig +from .bert_model import BertConfig cfg = edict({ 'task': 'NER', diff --git a/example/bert_clue/finetune_config.py b/model_zoo/bert/src/finetune_config.py similarity index 98% rename from example/bert_clue/finetune_config.py rename to model_zoo/bert/src/finetune_config.py index 8c5f55a62cd..e92842489b9 100644 --- a/example/bert_clue/finetune_config.py +++ b/model_zoo/bert/src/finetune_config.py @@ -19,7 +19,7 @@ config settings, will be used in finetune.py from easydict import EasyDict as edict import mindspore.common.dtype as mstype -from mindspore.model_zoo.Bert_NEZHA import BertConfig +from .bert_model import BertConfig cfg = edict({ 'task': 'NER', diff --git a/model_zoo/bert/src/fused_layer_norm.py b/model_zoo/bert/src/fused_layer_norm.py new file mode 100644 index 00000000000..ee3160b036f --- /dev/null +++ b/model_zoo/bert/src/fused_layer_norm.py @@ -0,0 +1,121 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""fused layernorm""" +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.ops.primitive import constexpr +import mindspore.common.dtype as mstype +from mindspore.nn.cell import Cell + +import numpy as np + + +__all__ = ['FusedLayerNorm'] + +@constexpr +def get_shape_for_norm(x_shape, begin_norm_axis): + print("input_shape: ", x_shape) + norm_shape = x_shape[begin_norm_axis:] + output_shape = (1, -1, 1, int(np.prod(norm_shape))) + print("output_shape: ", output_shape) + return output_shape + +class FusedLayerNorm(Cell): + r""" + Applies Layer Normalization over a mini-batch of inputs. + + Layer normalization is widely used in recurrent neural networks. 
It applies + normalization over a mini-batch of inputs for each single training case as described + in the paper `Layer Normalization `_. Unlike batch + normalization, layer normalization performs exactly the same computation at training and + testing times. It can be described using the following formula. It is applied across all channels + and pixel but only one batch size. + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + Args: + normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis + `begin_norm_axis ... R - 1`. + begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions + `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1. + begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters + will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with + the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1. + gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'ones'. + beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'zeros'. + use_batch_nrom (bool): Whether use batchnorm to preocess. + + Inputs: + - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`, + and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`. + + Outputs: + Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`. + + Examples: + >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) + >>> shape1 = x.shape()[1:] + >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) + >>> m(x) + """ + def __init__(self, + normalized_shape, + begin_norm_axis=-1, + begin_params_axis=-1, + gamma_init='ones', + beta_init='zeros', + use_batch_norm=False): + super(FusedLayerNorm, self).__init__() + if not isinstance(normalized_shape, (tuple, list)): + raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}." 
+ .format(normalized_shape, type(normalized_shape))) + self.normalized_shape = normalized_shape + self.begin_norm_axis = begin_norm_axis + self.begin_params_axis = begin_params_axis + self.gamma = Parameter(initializer( + gamma_init, normalized_shape), name="gamma") + self.beta = Parameter(initializer( + beta_init, normalized_shape), name="beta") + self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis) + + self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5) + self.use_batch_norm = use_batch_norm + + def construct(self, input_x): + if self.use_batch_norm and self.training: + ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0) + zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0) + shape_x = F.shape(input_x) + norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis) + input_x = F.reshape(input_x, norm_shape) + output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None) + output = F.reshape(output, shape_x) + y = output * self.gamma + self.beta + else: + y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) + return y + + def extend_repr(self): + """Display instance object as string.""" + s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format( + self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta) + return s diff --git a/example/bert_clue/sample_process.py b/model_zoo/bert/src/sample_process.py similarity index 100% rename from example/bert_clue/sample_process.py rename to model_zoo/bert/src/sample_process.py diff --git a/example/bert_clue/utils.py b/model_zoo/bert/src/utils.py similarity index 99% rename from example/bert_clue/utils.py rename to model_zoo/bert/src/utils.py index 1d05b957404..e4dd3e7b472 100644 --- a/example/bert_clue/utils.py +++ b/model_zoo/bert/src/utils.py @@ -30,8 +30,8 @@ from mindspore.train.parallel_utils import ParallelMode from mindspore.communication.management import get_group_size from mindspore import context from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel -from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad -from CRF import CRF +from .bert_for_pre_training import clip_grad +from .CRF import CRF GRADIENT_CLIP_TYPE = 1 GRADIENT_CLIP_VALUE = 1.0 diff --git a/tests/st/networks/models/bert/bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_tdt_lossscale.py index caacd9f16cb..38b207b6a62 100644 --- a/tests/st/networks/models/bert/bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_lossscale.py @@ -25,7 +25,8 @@ import mindspore.dataset.transforms.c_transforms as C from mindspore import context from mindspore import log as logger from mindspore.common.tensor import Tensor -from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell +from src.bert_model import BertConfig +from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell from mindspore.nn.optim import Lamb from mindspore.train.callback import Callback from mindspore.train.loss_scale_manager import DynamicLossScaleManager @@ -77,7 +78,8 @@ def get_config(version='base', batch_size=1): input_mask_from_dataset=True, token_type_ids_from_dataset=True, dtype=mstype.float32, - compute_type=mstype.float16) + compute_type=mstype.float16, + enable_fused_layernorm=False) else: bert_config = BertConfig(batch_size=batch_size) return bert_config diff --git 
a/tests/st/networks/models/bert/src/CRF.py b/tests/st/networks/models/bert/src/CRF.py new file mode 100644 index 00000000000..6c9fd5ea961 --- /dev/null +++ b/tests/st/networks/models/bert/src/CRF.py @@ -0,0 +1,177 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +CRF script. +''' + +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +import mindspore.common.dtype as mstype + +class CRF(nn.Cell): + ''' + Conditional Random Field + Args: + tag_to_index: The dict for tag to index mapping with extra "" and ""sign. + batch_size: Batch size, i.e., the length of the first dimension. + seq_length: Sequence length, i.e., the length of the second dimention. + is_training: Specifies whether to use training mode. + Returns: + Training mode: Tensor, total loss. + Evaluation mode: Tuple, the index for each step with the highest score; Tuple, the index for the last + step with the highest score. + ''' + def __init__(self, tag_to_index, batch_size=1, seq_length=128, is_training=True): + + super(CRF, self).__init__() + self.target_size = len(tag_to_index) + self.is_training = is_training + self.tag_to_index = tag_to_index + self.batch_size = batch_size + self.seq_length = seq_length + self.START_TAG = "" + self.STOP_TAG = "" + self.START_VALUE = Tensor(self.target_size-2, dtype=mstype.int32) + self.STOP_VALUE = Tensor(self.target_size-1, dtype=mstype.int32) + transitions = np.random.normal(size=(self.target_size, self.target_size)).astype(np.float32) + transitions[tag_to_index[self.START_TAG], :] = -10000 + transitions[:, tag_to_index[self.STOP_TAG]] = -10000 + self.transitions = Parameter(Tensor(transitions), name="transition_matrix") + self.cat = P.Concat(axis=-1) + self.argmax = P.ArgMaxWithValue(axis=-1) + self.log = P.Log() + self.exp = P.Exp() + self.sum = P.ReduceSum() + self.tile = P.Tile() + self.reduce_sum = P.ReduceSum(keep_dims=True) + self.reshape = P.Reshape() + self.expand = P.ExpandDims() + self.mean = P.ReduceMean() + init_alphas = np.ones(shape=(self.batch_size, self.target_size)) * -10000.0 + init_alphas[:, self.tag_to_index[self.START_TAG]] = 0. + self.init_alphas = Tensor(init_alphas, dtype=mstype.float32) + self.cast = P.Cast() + self.reduce_max = P.ReduceMax(keep_dims=True) + self.on_value = Tensor(1.0, dtype=mstype.float32) + self.off_value = Tensor(0.0, dtype=mstype.float32) + self.onehot = P.OneHot() + + def log_sum_exp(self, logits): + ''' + Compute the log_sum_exp score for normalization factor. + ''' + max_score = self.reduce_max(logits, -1) #16 5 5 + score = self.log(self.reduce_sum(self.exp(logits - max_score), -1)) + score = max_score + score + return score + + def _realpath_score(self, features, label): + ''' + Compute the emission and transition score for the real path. 
+ ''' + label = label * 1 + concat_A = self.tile(self.reshape(self.START_VALUE, (1,)), (self.batch_size,)) + concat_A = self.reshape(concat_A, (self.batch_size, 1)) + labels = self.cat((concat_A, label)) + onehot_label = self.onehot(label, self.target_size, self.on_value, self.off_value) + emits = features * onehot_label + labels = self.onehot(labels, self.target_size, self.on_value, self.off_value) + label1 = labels[:, 1:, :] + label2 = labels[:, :self.seq_length, :] + label1 = self.expand(label1, 3) + label2 = self.expand(label2, 2) + label_trans = label1 * label2 + transitions = self.expand(self.expand(self.transitions, 0), 0) + trans = transitions * label_trans + score = self.sum(emits, (1, 2)) + self.sum(trans, (1, 2, 3)) + stop_value_index = labels[:, (self.seq_length-1):self.seq_length, :] + stop_value = self.transitions[(self.target_size-1):self.target_size, :] + stop_score = stop_value * self.reshape(stop_value_index, (self.batch_size, self.target_size)) + score = score + self.sum(stop_score, 1) + score = self.reshape(score, (self.batch_size, -1)) + return score + + def _normalization_factor(self, features): + ''' + Compute the total score for all the paths. + ''' + forward_var = self.init_alphas + forward_var = self.expand(forward_var, 1) + for idx in range(self.seq_length): + feat = features[:, idx:(idx+1), :] + emit_score = self.reshape(feat, (self.batch_size, self.target_size, 1)) + next_tag_var = emit_score + self.transitions + forward_var + forward_var = self.log_sum_exp(next_tag_var) + forward_var = self.reshape(forward_var, (self.batch_size, 1, self.target_size)) + terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1)) + alpha = self.log_sum_exp(terminal_var) + alpha = self.reshape(alpha, (self.batch_size, -1)) + return alpha + + def _decoder(self, features): + ''' + Viterbi decode for evaluation. 
+ ''' + backpointers = () + forward_var = self.init_alphas + for idx in range(self.seq_length): + feat = features[:, idx:(idx+1), :] + feat = self.reshape(feat, (self.batch_size, self.target_size)) + bptrs_t = () + + next_tag_var = self.expand(forward_var, 1) + self.transitions + best_tag_id, best_tag_value = self.argmax(next_tag_var) + bptrs_t += (best_tag_id,) + forward_var = best_tag_value + feat + + backpointers += (bptrs_t,) + terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1)) + best_tag_id, _ = self.argmax(terminal_var) + return backpointers, best_tag_id + + def construct(self, features, label): + if self.is_training: + forward_score = self._normalization_factor(features) + gold_score = self._realpath_score(features, label) + return_value = self.mean(forward_score - gold_score) + else: + path_list, tag = self._decoder(features) + return_value = path_list, tag + return return_value + +def postprocess(backpointers, best_tag_id): + ''' + Do postprocess + ''' + best_tag_id = best_tag_id.asnumpy() + batch_size = len(best_tag_id) + best_path = [] + for i in range(batch_size): + best_path.append([]) + best_local_id = best_tag_id[i] + best_path[-1].append(best_local_id) + for bptrs_t in reversed(backpointers): + bptrs_t = bptrs_t[0].asnumpy() + local_idx = bptrs_t[i] + best_local_id = local_idx[best_local_id] + best_path[-1].append(best_local_id) + # Pop off the start tag (we dont want to return that to the caller) + best_path[-1].pop() + best_path[-1].reverse() + return best_path diff --git a/tests/st/networks/models/bert/src/__init__.py b/tests/st/networks/models/bert/src/__init__.py new file mode 100644 index 00000000000..4f4584a4b48 --- /dev/null +++ b/tests/st/networks/models/bert/src/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
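The CRF code above relies on the usual max-shifted log-sum-exp to keep the forward scores finite. A plain NumPy illustration of the trick (not the MindSpore ops used in the file):

```python
import numpy as np

def log_sum_exp(logits, axis=-1):
    # Subtract the max before exp() so large scores cannot overflow, then add it back.
    m = logits.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(logits - m).sum(axis=axis, keepdims=True))

scores = np.array([[1000.0, 999.0, 998.0]])   # naive exp() would overflow here
print(log_sum_exp(scores))                    # ~1000.41, finite
```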
+# ============================================================================ +"""Bert Init.""" +from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \ + BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \ + BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell +from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ + BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ + EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ + SaturateCast, CreateAttentionMaskFromInputMask + +__all__ = [ + "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss", + "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell", + "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", + "BertSelfAttention", "BertTransformer", "EmbeddingLookup", + "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", + "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask" +] diff --git a/tests/st/networks/models/bert/src/bert_for_pre_training.py b/tests/st/networks/models/bert/src/bert_for_pre_training.py new file mode 100644 index 00000000000..4732cc795fd --- /dev/null +++ b/tests/st/networks/models/bert/src/bert_for_pre_training.py @@ -0,0 +1,434 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert for pretraining.""" +import numpy as np + +import mindspore.nn as nn +from mindspore.common.initializer import initializer, TruncatedNormal +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.common import dtype as mstype +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.communication.management import get_group_size +from mindspore import context +from .bert_model import BertModel + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 + +_nn_clip_by_norm = nn.ClipByNorm() +clip_grad = C.MultitypeFuncGraph("clip_grad") +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. 
+ """ + if clip_type != 0 and clip_type != 1: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + +class GetMaskedLMOutput(nn.Cell): + """ + Get masked lm output. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, masked lm output. + """ + def __init__(self, config): + super(GetMaskedLMOutput, self).__init__() + self.width = config.hidden_size + self.reshape = P.Reshape() + self.gather = P.GatherV2() + + weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(self.width, + config.hidden_size, + weight_init=weight_init, + activation=config.hidden_act).to_float(config.compute_type) + self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type) + self.output_bias = Parameter( + initializer( + 'zero', + config.vocab_size), + name='output_bias') + self.matmul = P.MatMul(transpose_b=True) + self.log_softmax = nn.LogSoftmax(axis=-1) + self.shape_flat_offsets = (-1, 1) + self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32)) + self.last_idx = (-1,) + self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width) + self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32)) + self.cast = P.Cast() + self.compute_type = config.compute_type + self.dtype = config.dtype + + def construct(self, + input_tensor, + output_weights, + positions): + flat_offsets = self.reshape( + self.rng * self.seq_length_tensor, self.shape_flat_offsets) + flat_position = self.reshape(positions + flat_offsets, self.last_idx) + flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor) + input_tensor = self.gather(flat_sequence_tensor, flat_position, 0) + input_tensor = self.cast(input_tensor, self.compute_type) + output_weights = self.cast(output_weights, self.compute_type) + input_tensor = self.dense(input_tensor) + input_tensor = self.layernorm(input_tensor) + logits = self.matmul(input_tensor, output_weights) + logits = self.cast(logits, self.dtype) + logits = logits + self.output_bias + log_probs = self.log_softmax(logits) + return log_probs + + +class GetNextSentenceOutput(nn.Cell): + """ + Get next sentence output. + + Args: + config (BertConfig): The config of Bert. + + Returns: + Tensor, next sentence output. + """ + def __init__(self, config): + super(GetNextSentenceOutput, self).__init__() + self.log_softmax = P.LogSoftmax() + self.weight_init = TruncatedNormal(config.initializer_range) + self.dense = nn.Dense(config.hidden_size, 2, + weight_init=self.weight_init, has_bias=True).to_float(config.compute_type) + self.dtype = config.dtype + self.cast = P.Cast() + + def construct(self, input_tensor): + logits = self.dense(input_tensor) + logits = self.cast(logits, self.dtype) + log_prob = self.log_softmax(logits) + return log_prob + + +class BertPreTraining(nn.Cell): + """ + Bert pretraining network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. + + Returns: + Tensor, prediction_scores, seq_relationship_score. 
+ """ + def __init__(self, config, is_training, use_one_hot_embeddings): + super(BertPreTraining, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cls1 = GetMaskedLMOutput(config) + self.cls2 = GetNextSentenceOutput(config) + + def construct(self, input_ids, input_mask, token_type_id, + masked_lm_positions): + sequence_output, pooled_output, embedding_table = \ + self.bert(input_ids, token_type_id, input_mask) + prediction_scores = self.cls1(sequence_output, + embedding_table, + masked_lm_positions) + seq_relationship_score = self.cls2(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPretrainingLoss(nn.Cell): + """ + Provide bert pre-training loss. + + Args: + config (BertConfig): The config of BertModel. + + Returns: + Tensor, total loss. + """ + def __init__(self, config): + super(BertPretrainingLoss, self).__init__() + self.vocab_size = config.vocab_size + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + + def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids, + masked_lm_weights, next_sentence_labels): + """Defines the computation performed.""" + label_ids = self.reshape(masked_lm_ids, self.last_idx) + label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32) + one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value) + + per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx)) + numerator = self.reduce_sum(label_weights * per_example_loss, ()) + denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32) + masked_lm_loss = numerator / denominator + + # next_sentence_loss + labels = self.reshape(next_sentence_labels, self.last_idx) + one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum( + one_hot_labels * seq_relationship_score, self.last_idx)) + next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx) + + # total_loss + total_loss = masked_lm_loss + next_sentence_loss + + return total_loss + + +class BertNetworkWithLoss(nn.Cell): + """ + Provide bert pre-training loss through network. + + Args: + config (BertConfig): The config of BertModel. + is_training (bool): Specifies whether to use the training mode. + use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False. + + Returns: + Tensor, the loss of the network. 
+ """ + def __init__(self, config, is_training, use_one_hot_embeddings=False): + super(BertNetworkWithLoss, self).__init__() + self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings) + self.loss = BertPretrainingLoss(config) + self.cast = P.Cast() + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + prediction_scores, seq_relationship_score = \ + self.bert(input_ids, input_mask, token_type_id, masked_lm_positions) + total_loss = self.loss(prediction_scores, seq_relationship_score, + masked_lm_ids, masked_lm_weights, next_sentence_labels) + return self.cast(total_loss, mstype.float32) + + +class BertTrainOneStepCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. + """ + def __init__(self, network, optimizer, sens=1.0): + super(BertTrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + self.cast = P.Cast() + self.hyper_map = C.HyperMap() + + def set_sens(self, value): + self.sens = value + + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights): + """Defines the computation performed.""" + weights = self.weights + + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(F.tuple_to_array((self.sens,)), + mstype.float32)) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + + succ = self.optimizer(grads) + return F.depend(loss, succ) + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + + +class BertTrainOneStepWithLossScaleCell(nn.Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
+ """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + self.add_flags(has_effect=True) + def construct(self, + input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + next_sentence_labels, + masked_lm_positions, + masked_lm_ids, + masked_lm_weights, + self.cast(scaling_sens, + mstype.float32)) + # apply grad reducer on grads + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond, scaling_sens) + return F.depend(ret, succ) diff --git a/tests/st/networks/models/bert/src/bert_model.py b/tests/st/networks/models/bert/src/bert_model.py new file mode 100644 index 00000000000..310d330daaa --- /dev/null +++ b/tests/st/networks/models/bert/src/bert_model.py @@ -0,0 +1,949 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bert model.""" + +import math +import copy +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops.functional as F +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from .fused_layer_norm import FusedLayerNorm + + +class BertConfig: + """ + Configuration for `BertModel`. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. Default: 128. + vocab_size (int): The shape of each embedding vector. Default: 32000. + hidden_size (int): Size of the bert encoder layers. Default: 768. + num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder + cell. Default: 12. + num_attention_heads (int): Number of attention heads in the BertTransformer + encoder cell. Default: 12. + intermediate_size (int): Size of intermediate layer in the BertTransformer + encoder cell. Default: 3072. + hidden_act (str): Activation function used in the BertTransformer encoder + cell. Default: "gelu". + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + type_vocab_size (int): Size of token type vocab. Default: 16. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from + dataset. Default: True. + token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded + from dataset. Default: True. + dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. 
+ """ + def __init__(self, + batch_size, + seq_length=128, + vocab_size=32000, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float32, + enable_fused_layernorm=False): + self.batch_size = batch_size + self.seq_length = seq_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.input_mask_from_dataset = input_mask_from_dataset + self.token_type_ids_from_dataset = token_type_ids_from_dataset + self.use_relative_positions = use_relative_positions + self.dtype = dtype + self.compute_type = compute_type + self.enable_fused_layernorm = enable_fused_layernorm + + +class EmbeddingLookup(nn.Cell): + """ + A embeddings lookup table with a fixed dictionary and size. + + Args: + vocab_size (int): Size of the dictionary of embeddings. + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + """ + def __init__(self, + vocab_size, + embedding_size, + embedding_shape, + use_one_hot_embeddings=False, + initializer_range=0.02): + super(EmbeddingLookup, self).__init__() + self.vocab_size = vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [vocab_size, embedding_size]), + name='embedding_table') + self.expand = P.ExpandDims() + self.shape_flat = (-1,) + self.gather = P.GatherV2() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + + def construct(self, input_ids): + extended_ids = self.expand(input_ids, -1) + flat_ids = self.reshape(extended_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul( + one_hot_ids, self.embedding_table) + else: + output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) + output = self.reshape(output_for_reshape, self.shape) + return output, self.embedding_table + + +class EmbeddingPostprocessor(nn.Cell): + """ + Postprocessors apply positional and token type embeddings to word embeddings. + + Args: + embedding_size (int): The size of each embedding vector. + embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of + each embedding vector. + use_token_type (bool): Specifies whether to use token type embeddings. Default: False. 
+ token_type_vocab_size (int): Size of token type vocab. Default: 16. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + max_position_embeddings (int): Maximum length of sequences used in this + model. Default: 512. + dropout_prob (float): The dropout probability. Default: 0.1. + """ + def __init__(self, + embedding_size, + embedding_shape, + use_relative_positions=False, + use_token_type=False, + token_type_vocab_size=16, + use_one_hot_embeddings=False, + initializer_range=0.02, + max_position_embeddings=512, + dropout_prob=0.1): + super(EmbeddingPostprocessor, self).__init__() + self.use_token_type = use_token_type + self.token_type_vocab_size = token_type_vocab_size + self.use_one_hot_embeddings = use_one_hot_embeddings + self.max_position_embeddings = max_position_embeddings + self.embedding_table = Parameter(initializer + (TruncatedNormal(initializer_range), + [token_type_vocab_size, + embedding_size]), + name='embedding_table') + + self.shape_flat = (-1,) + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.1, mstype.float32) + self.array_mul = P.MatMul() + self.reshape = P.Reshape() + self.shape = tuple(embedding_shape) + self.layernorm = nn.LayerNorm((embedding_size,)) + self.dropout = nn.Dropout(1 - dropout_prob) + self.gather = P.GatherV2() + self.use_relative_positions = use_relative_positions + self.slice = P.StridedSlice() + self.full_position_embeddings = Parameter(initializer + (TruncatedNormal(initializer_range), + [max_position_embeddings, + embedding_size]), + name='full_position_embeddings') + + def construct(self, token_type_ids, word_embeddings): + output = word_embeddings + if self.use_token_type: + flat_ids = self.reshape(token_type_ids, self.shape_flat) + if self.use_one_hot_embeddings: + one_hot_ids = self.one_hot(flat_ids, + self.token_type_vocab_size, self.on_value, self.off_value) + token_type_embeddings = self.array_mul(one_hot_ids, + self.embedding_table) + else: + token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0) + token_type_embeddings = self.reshape(token_type_embeddings, self.shape) + output += token_type_embeddings + if not self.use_relative_positions: + _, seq, width = self.shape + position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1)) + position_embeddings = self.reshape(position_embeddings, (1, seq, width)) + output += position_embeddings + output = self.layernorm(output) + output = self.dropout(output) + return output + + +class BertOutput(nn.Cell): + """ + Apply a linear computation to hidden status and a residual computation to input. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + dropout_prob (float): The dropout probability. Default: 0.1. + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. 
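A NumPy sketch of the pattern this cell implements: project the sub-layer output, add the residual input, then layer-normalize. Shapes are illustrative; the dropout step and the gamma/beta parameters of the real layer norm are omitted for brevity.

```python
# Minimal sketch of the BertOutput data flow, assuming illustrative shapes.
import numpy as np

hidden_status = np.random.randn(4, 768).astype(np.float32)   # output of the previous sub-layer
input_tensor = np.random.randn(4, 768).astype(np.float32)    # residual branch
weight = (np.random.randn(768, 768) * 0.02).astype(np.float32)

out = hidden_status @ weight                                  # dense projection (bias omitted)
out = out + input_tensor                                      # residual connection
mean = out.mean(axis=-1, keepdims=True)
var = out.var(axis=-1, keepdims=True)
out = (out - mean) / np.sqrt(var + 1e-7)                      # layer norm without gamma/beta
print(out.shape)                                              # (4, 768)
```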
+ """ + def __init__(self, + in_channels, + out_channels, + initializer_range=0.02, + dropout_prob=0.1, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertOutput, self).__init__() + self.dense = nn.Dense(in_channels, out_channels, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.dropout_prob = dropout_prob + self.add = P.TensorAdd() + if compute_type == mstype.float16: + self.layernorm = FusedLayerNorm((out_channels,), + use_batch_norm=enable_fused_layernorm).to_float(compute_type) + else: + self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + self.cast = P.Cast() + + def construct(self, hidden_status, input_tensor): + output = self.dense(hidden_status) + output = self.dropout(output) + output = self.add(output, input_tensor) + output = self.layernorm(output) + return output + + +class RelaPosMatrixGenerator(nn.Cell): + """ + Generates matrix of relative positions between inputs. + + Args: + length (int): Length of one dim for the matrix to be generated. + max_relative_position (int): Max value of relative position. + """ + def __init__(self, length, max_relative_position): + super(RelaPosMatrixGenerator, self).__init__() + self._length = length + self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32) + self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32) + self.range_length = -length + 1 + + self.tile = P.Tile() + self.range_mat = P.Reshape() + self.sub = P.Sub() + self.expanddims = P.ExpandDims() + self.cast = P.Cast() + + def construct(self): + range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32) + range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1)) + tile_row_out = self.tile(range_vec_row_out, (self._length,)) + tile_col_out = self.tile(range_vec_col_out, (1, self._length)) + range_mat_out = self.range_mat(tile_row_out, (self._length, self._length)) + transpose_out = self.range_mat(tile_col_out, (self._length, self._length)) + distance_mat = self.sub(range_mat_out, transpose_out) + + distance_mat_clipped = C.clip_by_value(distance_mat, + self._min_relative_position, + self._max_relative_position) + + # Shift values to be >=0. Each integer still uniquely identifies a + # relative position difference. + final_mat = distance_mat_clipped + self._max_relative_position + return final_mat + + +class RelaPosEmbeddingsGenerator(nn.Cell): + """ + Generates tensor of size [length, length, depth]. + + Args: + length (int): Length of one dim for the matrix to be generated. + depth (int): Size of each attention head. + max_relative_position (int): Maxmum value of relative position. + initializer_range (float): Initialization value of TruncatedNormal. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 
+ """ + def __init__(self, + length, + depth, + max_relative_position, + initializer_range, + use_one_hot_embeddings=False): + super(RelaPosEmbeddingsGenerator, self).__init__() + self.depth = depth + self.vocab_size = max_relative_position * 2 + 1 + self.use_one_hot_embeddings = use_one_hot_embeddings + + self.embeddings_table = Parameter( + initializer(TruncatedNormal(initializer_range), + [self.vocab_size, self.depth]), + name='embeddings_for_position') + + self.relative_positions_matrix = RelaPosMatrixGenerator(length=length, + max_relative_position=max_relative_position) + self.reshape = P.Reshape() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.shape = P.Shape() + self.gather = P.GatherV2() # index_select + self.matmul = P.BatchMatMul() + + def construct(self): + relative_positions_matrix_out = self.relative_positions_matrix() + + # Generate embedding for each relative position of dimension depth. + if self.use_one_hot_embeddings: + flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,)) + one_hot_relative_positions_matrix = self.one_hot( + flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value) + embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table) + my_shape = self.shape(relative_positions_matrix_out) + (self.depth,) + embeddings = self.reshape(embeddings, my_shape) + else: + embeddings = self.gather(self.embeddings_table, + relative_positions_matrix_out, 0) + return embeddings + + +class SaturateCast(nn.Cell): + """ + Performs a safe saturating cast. This operation applies proper clamping before casting to prevent + the danger that the value will overflow or underflow. + + Args: + src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32. + dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. + """ + def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + super(SaturateCast, self).__init__() + np_type = mstype.dtype_to_nptype(dst_type) + min_type = np.finfo(np_type).min + max_type = np.finfo(np_type).max + + self.tensor_min_type = Tensor([min_type], dtype=src_type) + self.tensor_max_type = Tensor([max_type], dtype=src_type) + + self.min_op = P.Minimum() + self.max_op = P.Maximum() + self.cast = P.Cast() + self.dst_type = dst_type + + def construct(self, x): + out = self.max_op(x, self.tensor_min_type) + out = self.min_op(out, self.tensor_max_type) + return self.cast(out, self.dst_type) + + +class BertAttention(nn.Cell): + """ + Apply multi-headed attention from "from_tensor" to "to_tensor". + + Args: + batch_size (int): Batch size of input datasets. + from_tensor_width (int): Size of last dim of from_tensor. + to_tensor_width (int): Size of last dim of to_tensor. + from_seq_length (int): Length of from_tensor sequence. + to_seq_length (int): Length of to_tensor sequence. + num_attention_heads (int): Number of attention heads. Default: 1. + size_per_head (int): Size of each attention head. Default: 512. + query_act (str): Activation function for the query transform. Default: None. + key_act (str): Activation function for the key transform. Default: None. + value_act (str): Activation function for the value transform. Default: None. + has_attention_mask (bool): Specifies whether to use attention mask. Default: False. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. 
Default: 0.0. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d + tensor. Default: False. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + from_tensor_width, + to_tensor_width, + from_seq_length, + to_seq_length, + num_attention_heads=1, + size_per_head=512, + query_act=None, + key_act=None, + value_act=None, + has_attention_mask=False, + attention_probs_dropout_prob=0.0, + use_one_hot_embeddings=False, + initializer_range=0.02, + do_return_2d_tensor=False, + use_relative_positions=False, + compute_type=mstype.float32): + + super(BertAttention, self).__init__() + self.batch_size = batch_size + self.from_seq_length = from_seq_length + self.to_seq_length = to_seq_length + self.num_attention_heads = num_attention_heads + self.size_per_head = size_per_head + self.has_attention_mask = has_attention_mask + self.use_relative_positions = use_relative_positions + + self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type) + self.reshape = P.Reshape() + self.shape_from_2d = (-1, from_tensor_width) + self.shape_to_2d = (-1, to_tensor_width) + weight = TruncatedNormal(initializer_range) + units = num_attention_heads * size_per_head + self.query_layer = nn.Dense(from_tensor_width, + units, + activation=query_act, + weight_init=weight).to_float(compute_type) + self.key_layer = nn.Dense(to_tensor_width, + units, + activation=key_act, + weight_init=weight).to_float(compute_type) + self.value_layer = nn.Dense(to_tensor_width, + units, + activation=value_act, + weight_init=weight).to_float(compute_type) + + self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head) + self.shape_to = ( + batch_size, to_seq_length, num_attention_heads, size_per_head) + + self.matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.multiply = P.Mul() + self.transpose = P.Transpose() + self.trans_shape = (0, 2, 1, 3) + self.trans_shape_relative = (2, 0, 1, 3) + self.trans_shape_position = (1, 2, 0, 3) + self.multiply_data = Tensor([-10000.0,], dtype=compute_type) + self.batch_num = batch_size * num_attention_heads + self.matmul = P.BatchMatMul() + + self.softmax = nn.Softmax() + self.dropout = nn.Dropout(1 - attention_probs_dropout_prob) + + if self.has_attention_mask: + self.expand_dims = P.ExpandDims() + self.sub = P.Sub() + self.add = P.TensorAdd() + self.cast = P.Cast() + self.get_dtype = P.DType() + if do_return_2d_tensor: + self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head) + else: + self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head) + + self.cast_compute_type = SaturateCast(dst_type=compute_type) + if self.use_relative_positions: + self._generate_relative_positions_embeddings = \ + RelaPosEmbeddingsGenerator(length=to_seq_length, + depth=size_per_head, + max_relative_position=16, + initializer_range=initializer_range, + use_one_hot_embeddings=use_one_hot_embeddings) + + def construct(self, from_tensor, to_tensor, attention_mask): + # reshape 2d/3d input tensors to 2d + from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) + to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d) + query_out = 
self.query_layer(from_tensor_2d) + key_out = self.key_layer(to_tensor_2d) + value_out = self.value_layer(to_tensor_2d) + + query_layer = self.reshape(query_out, self.shape_from) + query_layer = self.transpose(query_layer, self.trans_shape) + key_layer = self.reshape(key_out, self.shape_to) + key_layer = self.transpose(key_layer, self.trans_shape) + + attention_scores = self.matmul_trans_b(query_layer, key_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_keys' = [F|T, F|T, H] + relations_keys = self._generate_relative_positions_embeddings() + relations_keys = self.cast_compute_type(relations_keys) + # query_layer_t is [F, B, N, H] + query_layer_t = self.transpose(query_layer, self.trans_shape_relative) + # query_layer_r is [F, B * N, H] + query_layer_r = self.reshape(query_layer_t, + (self.from_seq_length, + self.batch_num, + self.size_per_head)) + # key_position_scores is [F, B * N, F|T] + key_position_scores = self.matmul_trans_b(query_layer_r, + relations_keys) + # key_position_scores_r is [F, B, N, F|T] + key_position_scores_r = self.reshape(key_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.from_seq_length)) + # key_position_scores_r_t is [B, N, F, F|T] + key_position_scores_r_t = self.transpose(key_position_scores_r, + self.trans_shape_position) + attention_scores = attention_scores + key_position_scores_r_t + + attention_scores = self.multiply(self.scores_mul, attention_scores) + + if self.has_attention_mask: + attention_mask = self.expand_dims(attention_mask, 1) + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + + attention_probs = self.softmax(attention_scores) + attention_probs = self.dropout(attention_probs) + + value_layer = self.reshape(value_out, self.shape_to) + value_layer = self.transpose(value_layer, self.trans_shape) + context_layer = self.matmul(attention_probs, value_layer) + + # use_relative_position, supplementary logic + if self.use_relative_positions: + # 'relations_values' = [F|T, F|T, H] + relations_values = self._generate_relative_positions_embeddings() + relations_values = self.cast_compute_type(relations_values) + # attention_probs_t is [F, B, N, T] + attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative) + # attention_probs_r is [F, B * N, T] + attention_probs_r = self.reshape( + attention_probs_t, + (self.from_seq_length, + self.batch_num, + self.to_seq_length)) + # value_position_scores is [F, B * N, H] + value_position_scores = self.matmul(attention_probs_r, + relations_values) + # value_position_scores_r is [F, B, N, H] + value_position_scores_r = self.reshape(value_position_scores, + (self.from_seq_length, + self.batch_size, + self.num_attention_heads, + self.size_per_head)) + # value_position_scores_r_t is [B, N, F, H] + value_position_scores_r_t = self.transpose(value_position_scores_r, + self.trans_shape_position) + context_layer = context_layer + value_position_scores_r_t + + context_layer = self.transpose(context_layer, self.trans_shape) + context_layer = self.reshape(context_layer, self.shape_return) + + return context_layer + + +class BertSelfAttention(nn.Cell): + """ + Apply self-attention. + + Args: + batch_size (int): Batch size of input dataset. + seq_length (int): Length of input sequence. 
+ hidden_size (int): Size of the bert encoder layers. + num_attention_heads (int): Number of attention heads. Default: 12. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. + """ + def __init__(self, + batch_size, + seq_length, + hidden_size, + num_attention_heads=12, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError("The hidden size (%d) is not a multiple of the number " + "of attention heads (%d)" % (hidden_size, num_attention_heads)) + + self.size_per_head = int(hidden_size / num_attention_heads) + + self.attention = BertAttention( + batch_size=batch_size, + from_tensor_width=hidden_size, + to_tensor_width=hidden_size, + from_seq_length=seq_length, + to_seq_length=seq_length, + num_attention_heads=num_attention_heads, + size_per_head=self.size_per_head, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + use_relative_positions=use_relative_positions, + has_attention_mask=True, + do_return_2d_tensor=True, + compute_type=compute_type) + + self.output = BertOutput(in_channels=hidden_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + + def construct(self, input_tensor, attention_mask): + input_tensor = self.reshape(input_tensor, self.shape) + attention_output = self.attention(input_tensor, input_tensor, attention_mask) + output = self.output(attention_output, input_tensor) + return output + + +class BertEncoderCell(nn.Cell): + """ + Encoder cells used in BertTransformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the bert encoder layers. Default: 768. + seq_length (int): Length of input sequence. Default: 512. + num_attention_heads (int): Number of attention heads. Default: 12. + intermediate_size (int): Size of intermediate layer. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.02. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. 
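Going back to `BertAttention` above, a NumPy sketch of its head-split bookkeeping: the 2-D projection output is reshaped to `(batch, seq, heads, head_size)` and transposed to `(batch, heads, seq, head_size)` so that per-head attention scores come out of a single batched matmul. Shapes are illustrative only.

```python
# Illustrative shapes; the real cell does the same reshape/transpose with
# P.Reshape and P.Transpose before P.BatchMatMul(transpose_b=True).
import numpy as np

batch, seq, heads, head_size = 2, 4, 3, 8
query_2d = np.random.randn(batch * seq, heads * head_size).astype(np.float32)
key_2d = np.random.randn(batch * seq, heads * head_size).astype(np.float32)

query = query_2d.reshape(batch, seq, heads, head_size).transpose(0, 2, 1, 3)
key = key_2d.reshape(batch, seq, heads, head_size).transpose(0, 2, 1, 3)

scores = (query @ key.transpose(0, 1, 3, 2)) / np.sqrt(head_size)
print(scores.shape)   # (2, 3, 4, 4): one seq-by-seq score matrix per batch and head
```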
+ """ + def __init__(self, + batch_size, + hidden_size=768, + seq_length=512, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.02, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + enable_fused_layernorm=False): + super(BertEncoderCell, self).__init__() + self.attention = BertSelfAttention( + batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + self.intermediate = nn.Dense(in_channels=hidden_size, + out_channels=intermediate_size, + activation=hidden_act, + weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) + self.output = BertOutput(in_channels=intermediate_size, + out_channels=hidden_size, + initializer_range=initializer_range, + dropout_prob=hidden_dropout_prob, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + + def construct(self, hidden_states, attention_mask): + # self-attention + attention_output = self.attention(hidden_states, attention_mask) + # feed construct + intermediate_output = self.intermediate(attention_output) + # add and normalize + output = self.output(intermediate_output, attention_output) + return output + + +class BertTransformer(nn.Cell): + """ + Multi-layer bert transformer. + + Args: + batch_size (int): Batch size of input dataset. + hidden_size (int): Size of the encoder layers. + seq_length (int): Length of input sequence. + num_hidden_layers (int): Number of hidden layers in encoder cells. + num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. + intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. + attention_probs_dropout_prob (float): The dropout probability for + BertAttention. Default: 0.1. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. + initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. + hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. + use_relative_positions (bool): Specifies whether to use relative positions. Default: False. + hidden_act (str): Activation function used in the encoder cells. Default: "gelu". + compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + return_all_encoders (bool): Specifies whether to return all encoders. Default: False. 
+ """ + def __init__(self, + batch_size, + hidden_size, + seq_length, + num_hidden_layers, + num_attention_heads=12, + intermediate_size=3072, + attention_probs_dropout_prob=0.1, + use_one_hot_embeddings=False, + initializer_range=0.02, + hidden_dropout_prob=0.1, + use_relative_positions=False, + hidden_act="gelu", + compute_type=mstype.float32, + return_all_encoders=False, + enable_fused_layernorm=False): + super(BertTransformer, self).__init__() + self.return_all_encoders = return_all_encoders + + layers = [] + for _ in range(num_hidden_layers): + layer = BertEncoderCell(batch_size=batch_size, + hidden_size=hidden_size, + seq_length=seq_length, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=initializer_range, + hidden_dropout_prob=hidden_dropout_prob, + use_relative_positions=use_relative_positions, + hidden_act=hidden_act, + compute_type=compute_type, + enable_fused_layernorm=enable_fused_layernorm) + layers.append(layer) + + self.layers = nn.CellList(layers) + + self.reshape = P.Reshape() + self.shape = (-1, hidden_size) + self.out_shape = (batch_size, seq_length, hidden_size) + + def construct(self, input_tensor, attention_mask): + prev_output = self.reshape(input_tensor, self.shape) + + all_encoder_layers = () + for layer_module in self.layers: + layer_output = layer_module(prev_output, attention_mask) + prev_output = layer_output + + if self.return_all_encoders: + layer_output = self.reshape(layer_output, self.out_shape) + all_encoder_layers = all_encoder_layers + (layer_output,) + + if not self.return_all_encoders: + prev_output = self.reshape(prev_output, self.out_shape) + all_encoder_layers = all_encoder_layers + (prev_output,) + return all_encoder_layers + + +class CreateAttentionMaskFromInputMask(nn.Cell): + """ + Create attention mask according to input mask. + + Args: + config (Class): Configuration for BertModel. + """ + def __init__(self, config): + super(CreateAttentionMaskFromInputMask, self).__init__() + self.input_mask_from_dataset = config.input_mask_from_dataset + self.input_mask = None + + if not self.input_mask_from_dataset: + self.input_mask = initializer( + "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor() + + self.cast = P.Cast() + self.reshape = P.Reshape() + self.shape = (config.batch_size, 1, config.seq_length) + self.broadcast_ones = initializer( + "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor() + self.batch_matmul = P.BatchMatMul() + + def construct(self, input_mask): + if not self.input_mask_from_dataset: + input_mask = self.input_mask + + input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32) + attention_mask = self.batch_matmul(self.broadcast_ones, input_mask) + return attention_mask + + +class BertModel(nn.Cell): + """ + Bidirectional Encoder Representations from Transformers. + + Args: + config (Class): Configuration for BertModel. + is_training (bool): True for training mode. False for eval mode. + use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. 
+ """ + def __init__(self, + config, + is_training, + use_one_hot_embeddings=False): + super(BertModel, self).__init__() + config = copy.deepcopy(config) + if not is_training: + config.hidden_dropout_prob = 0.0 + config.attention_probs_dropout_prob = 0.0 + + self.input_mask_from_dataset = config.input_mask_from_dataset + self.token_type_ids_from_dataset = config.token_type_ids_from_dataset + self.batch_size = config.batch_size + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.num_hidden_layers = config.num_hidden_layers + self.embedding_size = config.hidden_size + self.token_type_ids = None + + self.last_idx = self.num_hidden_layers - 1 + output_embedding_shape = [self.batch_size, self.seq_length, + self.embedding_size] + + if not self.token_type_ids_from_dataset: + self.token_type_ids = initializer( + "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor() + + self.bert_embedding_lookup = EmbeddingLookup( + vocab_size=config.vocab_size, + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range) + + self.bert_embedding_postprocessor = EmbeddingPostprocessor( + embedding_size=self.embedding_size, + embedding_shape=output_embedding_shape, + use_relative_positions=config.use_relative_positions, + use_token_type=True, + token_type_vocab_size=config.type_vocab_size, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=0.02, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + + self.bert_encoder = BertTransformer( + batch_size=self.batch_size, + hidden_size=self.hidden_size, + seq_length=self.seq_length, + num_attention_heads=config.num_attention_heads, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=config.intermediate_size, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + use_one_hot_embeddings=use_one_hot_embeddings, + initializer_range=config.initializer_range, + hidden_dropout_prob=config.hidden_dropout_prob, + use_relative_positions=config.use_relative_positions, + hidden_act=config.hidden_act, + compute_type=config.compute_type, + return_all_encoders=True, + enable_fused_layernorm=config.enable_fused_layernorm) + + self.cast = P.Cast() + self.dtype = config.dtype + self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.slice = P.StridedSlice() + + self.squeeze_1 = P.Squeeze(axis=1) + self.dense = nn.Dense(self.hidden_size, self.hidden_size, + activation="tanh", + weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) + + def construct(self, input_ids, token_type_ids, input_mask): + + # embedding + if not self.token_type_ids_from_dataset: + token_type_ids = self.token_type_ids + word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids) + embedding_output = self.bert_embedding_postprocessor(token_type_ids, + word_embeddings) + + # attention mask [batch_size, seq_length, seq_length] + attention_mask = self._create_attention_mask_from_input_mask(input_mask) + + # bert encoder + encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), + attention_mask) + + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + + # pooler + sequence_slice = self.slice(sequence_output, + (0, 0, 0), + (self.batch_size, 1, self.hidden_size), + (1, 1, 1)) + 
first_token = self.squeeze_1(sequence_slice) + pooled_output = self.dense(first_token) + pooled_output = self.cast(pooled_output, self.dtype) + + return sequence_output, pooled_output, embedding_tables diff --git a/tests/st/networks/models/bert/src/cluener_evaluation.py b/tests/st/networks/models/bert/src/cluener_evaluation.py new file mode 100644 index 00000000000..c2c6770a4ad --- /dev/null +++ b/tests/st/networks/models/bert/src/cluener_evaluation.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +'''bert clue evaluation''' + +import json +import numpy as np +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +import tokenization +from sample_process import label_generation, process_one_example_p +from .evaluation_config import cfg +from .CRF import postprocess + +vocab_file = "./vocab.txt" +tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file) + +def process(model, text, sequence_length): + """ + process text. + """ + data = [text] + features = [] + res = [] + ids = [] + for i in data: + feature = process_one_example_p(tokenizer_, i, max_seq_len=sequence_length) + features.append(feature) + input_ids, input_mask, token_type_id = feature + input_ids = Tensor(np.array(input_ids), mstype.int32) + input_mask = Tensor(np.array(input_mask), mstype.int32) + token_type_id = Tensor(np.array(token_type_id), mstype.int32) + if cfg.use_crf: + backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1)) + best_path = postprocess(backpointers, best_tag_id) + logits = [] + for ele in best_path: + logits.extend(ele) + ids = logits + else: + logits = model.predict(input_ids, input_mask, token_type_id, Tensor(1)) + ids = logits.asnumpy() + ids = np.argmax(ids, axis=-1) + ids = list(ids) + res = label_generation(text, ids) + return res + +def submit(model, path, sequence_length): + """ + submit task + """ + data = [] + for line in open(path): + if not line.strip(): + continue + oneline = json.loads(line.strip()) + res = process(model, oneline["text"], sequence_length) + print("text", oneline["text"]) + print("res:", res) + data.append(json.dumps({"label": res}, ensure_ascii=False)) + open("ner_predict.json", "w").write("\n".join(data)) diff --git a/tests/st/networks/models/bert/src/config.py b/tests/st/networks/models/bert/src/config.py new file mode 100644 index 00000000000..d1062b78eec --- /dev/null +++ b/tests/st/networks/models/bert/src/config.py @@ -0,0 +1,118 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in dataset.py, run_pretrain.py +""" +from easydict import EasyDict as edict +import mindspore.common.dtype as mstype +from .bert_model import BertConfig +cfg = edict({ + 'bert_network': 'base', + 'loss_scale_value': 65536, + 'scale_factor': 2, + 'scale_window': 1000, + 'optimizer': 'Lamb', + 'AdamWeightDecayDynamicLR': edict({ + 'learning_rate': 3e-5, + 'end_learning_rate': 1e-10, + 'power': 5.0, + 'weight_decay': 1e-5, + 'eps': 1e-6, + 'warmup_steps': 10000, + }), + 'Lamb': edict({ + 'start_learning_rate': 3e-5, + 'end_learning_rate': 1e-10, + 'power': 10.0, + 'warmup_steps': 10000, + 'weight_decay': 0.01, + 'eps': 1e-6, + }), + 'Momentum': edict({ + 'learning_rate': 2e-5, + 'momentum': 0.9, + }), +}) + +''' +Including two kinds of network: \ +base: Goole BERT-base(the base version of BERT model). +large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \ + Functional Relative Posetional Encoding as an effective positional encoding scheme). +''' +if cfg.bert_network == 'base': + bert_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=21136, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16 + ) +if cfg.bert_network == 'nezha': + bert_net_cfg = BertConfig( + batch_size=32, + seq_length=128, + vocab_size=21136, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=True, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16 + ) +if cfg.bert_network == 'large': + bert_net_cfg = BertConfig( + batch_size=16, + seq_length=512, + vocab_size=30528, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=True + ) diff --git a/tests/st/networks/models/bert/src/dataset.py b/tests/st/networks/models/bert/src/dataset.py new file mode 100644 index 00000000000..1828fac4544 --- /dev/null +++ b/tests/st/networks/models/bert/src/dataset.py @@ -0,0 +1,59 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use 
this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in run_pretrain.py +""" +import os +import mindspore.common.dtype as mstype +import mindspore.dataset.engine.datasets as de +import mindspore.dataset.transforms.c_transforms as C +from mindspore import log as logger +from .config import bert_net_cfg + + +def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true", + data_sink_steps=1, data_dir=None, schema_dir=None): + """create train dataset""" + # apply repeat operations + repeat_count = epoch_size + files = os.listdir(data_dir) + data_files = [] + for file_name in files: + if "tfrecord" in file_name: + data_files.append(os.path.join(data_dir, file_name)) + ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], + shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, + shard_equal_rows=True) + ori_dataset_size = ds.get_dataset_size() + new_size = ori_dataset_size + if enable_data_sink == "true": + new_size = data_sink_steps * bert_net_cfg.batch_size + ds.set_dataset_size(new_size) + new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size()) + type_cast_op = C.TypeCast(mstype.int32) + ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) + ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) + ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + # apply batch operations + ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) + ds = ds.repeat(new_repeat_count) + logger.info("data size: {}".format(ds.get_dataset_size())) + logger.info("repeatcount: {}".format(ds.get_repeat_count())) + return ds, new_repeat_count diff --git a/tests/st/networks/models/bert/src/evaluation_config.py b/tests/st/networks/models/bert/src/evaluation_config.py new file mode 100644 index 00000000000..b18c5643b00 --- /dev/null +++ b/tests/st/networks/models/bert/src/evaluation_config.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
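Stepping back to `create_bert_dataset` above: when data sinking is enabled, the dataset size is reset to `data_sink_steps * batch_size` rows and the repeat count is rescaled so roughly the same total number of samples is still consumed. A back-of-the-envelope sketch with made-up numbers:

```python
# Illustrative numbers only; none of these values come from the patch.
epoch_size = 40
ori_dataset_size = 320000      # samples in the TFRecord shard before batching
data_sink_steps = 100
batch_size = 32

new_size = data_sink_steps * batch_size                    # 3200 samples per sink cycle
new_repeat_count = int(epoch_size * ori_dataset_size // new_size)
total_steps = new_repeat_count * data_sink_steps
print(new_size, new_repeat_count, total_steps)             # 3200 4000 400000
```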
+# ============================================================================ + +""" +config settings, will be used in finetune.py +""" + +from easydict import EasyDict as edict +import mindspore.common.dtype as mstype +from .bert_model import BertConfig + +cfg = edict({ + 'task': 'NER', + 'num_labels': 41, + 'data_file': '/your/path/evaluation.tfrecord', + 'schema_file': '/your/path/schema.json', + 'finetune_ckpt': '/your/path/your.ckpt', + 'use_crf': False, + 'clue_benchmark': False, +}) + +bert_net_cfg = BertConfig( + batch_size=16 if not cfg.clue_benchmark else 1, + seq_length=128, + vocab_size=21128, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, +) diff --git a/tests/st/networks/models/bert/src/finetune_config.py b/tests/st/networks/models/bert/src/finetune_config.py new file mode 100644 index 00000000000..e92842489b9 --- /dev/null +++ b/tests/st/networks/models/bert/src/finetune_config.py @@ -0,0 +1,119 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +""" +config settings, will be used in finetune.py +""" + +from easydict import EasyDict as edict +import mindspore.common.dtype as mstype +from .bert_model import BertConfig + +cfg = edict({ + 'task': 'NER', + 'num_labels': 41, + 'data_file': '/your/path/train.tfrecord', + 'schema_file': '/your/path/schema.json', + 'epoch_num': 5, + 'ckpt_prefix': 'bert', + 'ckpt_dir': None, + 'pre_training_ckpt': '/your/path/pre_training.ckpt', + 'use_crf': False, + 'optimizer': 'Lamb', + 'AdamWeightDecayDynamicLR': edict({ + 'learning_rate': 2e-5, + 'end_learning_rate': 1e-7, + 'power': 1.0, + 'weight_decay': 1e-5, + 'eps': 1e-6, + }), + 'Lamb': edict({ + 'start_learning_rate': 2e-5, + 'end_learning_rate': 1e-7, + 'power': 1.0, + 'decay_filter': lambda x: False, + }), + 'Momentum': edict({ + 'learning_rate': 2e-5, + 'momentum': 0.9, + }), +}) + +bert_net_cfg = BertConfig( + batch_size=16, + seq_length=128, + vocab_size=21128, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, +) + +tag_to_index = { + "O": 0, + "S_address": 1, + "B_address": 2, + "M_address": 3, + "E_address": 4, + "S_book": 5, + "B_book": 6, + "M_book": 7, + "E_book": 8, + "S_company": 9, + "B_company": 10, + "M_company": 11, + "E_company": 12, + "S_game": 13, + "B_game": 14, + "M_game": 15, + "E_game": 16, + "S_government": 17, + "B_government": 18, + "M_government": 19, + "E_government": 20, + "S_movie": 21, + "B_movie": 22, + "M_movie": 23, + "E_movie": 24, + "S_name": 25, + "B_name": 26, + "M_name": 27, + "E_name": 28, + "S_organization": 29, + "B_organization": 30, + "M_organization": 31, + "E_organization": 32, + "S_position": 33, + "B_position": 34, + "M_position": 35, + "E_position": 36, + "S_scene": 37, + "B_scene": 38, + "M_scene": 39, + "E_scene": 40, + "": 41, + "": 42 +} diff --git a/tests/st/networks/models/bert/src/fused_layer_norm.py b/tests/st/networks/models/bert/src/fused_layer_norm.py new file mode 100644 index 00000000000..ee3160b036f --- /dev/null +++ b/tests/st/networks/models/bert/src/fused_layer_norm.py @@ -0,0 +1,121 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""fused layernorm""" +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.ops.primitive import constexpr +import mindspore.common.dtype as mstype +from mindspore.nn.cell import Cell + +import numpy as np + + +__all__ = ['FusedLayerNorm'] + +@constexpr +def get_shape_for_norm(x_shape, begin_norm_axis): + print("input_shape: ", x_shape) + norm_shape = x_shape[begin_norm_axis:] + output_shape = (1, -1, 1, int(np.prod(norm_shape))) + print("output_shape: ", output_shape) + return output_shape + +class FusedLayerNorm(Cell): + r""" + Applies Layer Normalization over a mini-batch of inputs. + + Layer normalization is widely used in recurrent neural networks. It applies + normalization over a mini-batch of inputs for each single training case as described + in the paper `Layer Normalization `_. Unlike batch + normalization, layer normalization performs exactly the same computation at training and + testing times. It can be described using the following formula. It is applied across all channels + and pixel but only one batch size. + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + Args: + normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis + `begin_norm_axis ... R - 1`. + begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions + `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1. + begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters + will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with + the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1. + gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'ones'. + beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. + The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', + 'he_uniform', etc. Default: 'zeros'. + use_batch_nrom (bool): Whether use batchnorm to preocess. + + Inputs: + - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`, + and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`. + + Outputs: + Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`. + + Examples: + >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) + >>> shape1 = x.shape()[1:] + >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) + >>> m(x) + """ + def __init__(self, + normalized_shape, + begin_norm_axis=-1, + begin_params_axis=-1, + gamma_init='ones', + beta_init='zeros', + use_batch_norm=False): + super(FusedLayerNorm, self).__init__() + if not isinstance(normalized_shape, (tuple, list)): + raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}." 
+ .format(normalized_shape, type(normalized_shape))) + self.normalized_shape = normalized_shape + self.begin_norm_axis = begin_norm_axis + self.begin_params_axis = begin_params_axis + self.gamma = Parameter(initializer( + gamma_init, normalized_shape), name="gamma") + self.beta = Parameter(initializer( + beta_init, normalized_shape), name="beta") + self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis) + + self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5) + self.use_batch_norm = use_batch_norm + + def construct(self, input_x): + if self.use_batch_norm and self.training: + ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0) + zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0) + shape_x = F.shape(input_x) + norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis) + input_x = F.reshape(input_x, norm_shape) + output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None) + output = F.reshape(output, shape_x) + y = output * self.gamma + self.beta + else: + y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) + return y + + def extend_repr(self): + """Display instance object as string.""" + s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format( + self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta) + return s diff --git a/tests/st/networks/models/bert/src/sample_process.py b/tests/st/networks/models/bert/src/sample_process.py new file mode 100644 index 00000000000..59f3e76a31a --- /dev/null +++ b/tests/st/networks/models/bert/src/sample_process.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
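Referring back to `FusedLayerNorm` above, a NumPy sketch of why its batch-norm path works: after reshaping to `(1, -1, 1, prod(normalized_shape))`, every leading index becomes its own "channel", so per-channel batch-norm statistics coincide with per-row layer-norm statistics. The gamma/beta scaling that the real cell applies afterwards is omitted here.

```python
# Equivalence check with illustrative shapes: layer norm over the last axis vs.
# per-channel normalization of the reshaped tensor.
import numpy as np

x = np.random.randn(4, 8).astype(np.float32)

# plain layer norm over the last axis
ln = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-5)

# reshape so every row becomes one "channel" of a (1, C, 1, W) tensor, then
# normalize each channel over its remaining axes, as training-mode BatchNorm does
xr = x.reshape(1, -1, 1, 8)
bn = (xr - xr.mean(axis=(0, 2, 3), keepdims=True)) / np.sqrt(xr.var(axis=(0, 2, 3), keepdims=True) + 1e-5)

print(np.allclose(ln, bn.reshape(4, 8), atol=1e-5))   # True
```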
+# ============================================================================ + +"""process txt""" + +import re +import json + +def process_one_example_p(tokenizer, text, max_seq_len=128): + """process one testline""" + textlist = list(text) + tokens = [] + for _, word in enumerate(textlist): + token = tokenizer.tokenize(word) + tokens.extend(token) + if len(tokens) >= max_seq_len - 1: + tokens = tokens[0:(max_seq_len - 2)] + ntokens = [] + segment_ids = [] + label_ids = [] + ntokens.append("[CLS]") + segment_ids.append(0) + for _, token in enumerate(tokens): + ntokens.append(token) + segment_ids.append(0) + ntokens.append("[SEP]") + segment_ids.append(0) + input_ids = tokenizer.convert_tokens_to_ids(ntokens) + input_mask = [1] * len(input_ids) + while len(input_ids) < max_seq_len: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + label_ids.append(0) + ntokens.append("**NULL**") + assert len(input_ids) == max_seq_len + assert len(input_mask) == max_seq_len + assert len(segment_ids) == max_seq_len + + feature = (input_ids, input_mask, segment_ids) + return feature + +def label_generation(text, probs): + """generate label""" + data = [text] + probs = [probs] + result = [] + label2id = json.loads(open("./label2id.json").read()) + id2label = [k for k, v in label2id.items()] + + for index, prob in enumerate(probs): + for v in prob[1:len(data[index]) + 1]: + result.append(id2label[int(v)]) + + labels = {} + start = None + index = 0 + for _, t in zip("".join(data), result): + if re.search("^[BS]", t): + if start is not None: + label = result[index - 1][2:] + if labels.get(label): + te_ = text[start:index] + labels[label][te_] = [[start, index - 1]] + else: + te_ = text[start:index] + labels[label] = {te_: [[start, index - 1]]} + start = index + if re.search("^O", t): + if start is not None: + label = result[index - 1][2:] + if labels.get(label): + te_ = text[start:index] + labels[label][te_] = [[start, index - 1]] + else: + te_ = text[start:index] + labels[label] = {te_: [[start, index - 1]]} + start = None + index += 1 + if start is not None: + label = result[start][2:] + if labels.get(label): + te_ = text[start:index] + labels[label][te_] = [[start, index - 1]] + else: + te_ = text[start:index] + labels[label] = {te_: [[start, index - 1]]} + return labels diff --git a/tests/st/networks/models/bert/src/utils.py b/tests/st/networks/models/bert/src/utils.py new file mode 100644 index 00000000000..e4dd3e7b472 --- /dev/null +++ b/tests/st/networks/models/bert/src/utils.py @@ -0,0 +1,263 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Functional Cells used in Bert finetune and evaluation. 
+'''
+
+import mindspore.nn as nn
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.ops import composite as C
+from mindspore.common.tensor import Tensor
+from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common import dtype as mstype
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
+from mindspore.train.parallel_utils import ParallelMode
+from mindspore.communication.management import get_group_size
+from mindspore import context
+from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
+from .bert_for_pre_training import clip_grad
+from .CRF import CRF
+
+GRADIENT_CLIP_TYPE = 1
+GRADIENT_CLIP_VALUE = 1.0
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+reciprocal = P.Reciprocal()
+
+@grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    return grad * reciprocal(scale)
+
+class BertFinetuneCell(nn.Cell):
+    """
+    Specifically defined for finetuning, where only four input tensors are needed.
+    """
+    def __init__(self, network, optimizer, scale_update_cell=None):
+
+        super(BertFinetuneCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = ParameterTuple(network.trainable_params())
+        self.optimizer = optimizer
+        self.grad = C.GradOperation('grad',
+                                    get_by_list=True,
+                                    sens_param=True)
+        self.reducer_flag = False
+        self.allreduce = P.AllReduce()
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = None
+        if self.reducer_flag:
+            mean = context.get_auto_parallel_context("mirror_mean")
+            degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.hyper_map = C.HyperMap()
+        self.loss_scale = None
+        self.loss_scaling_manager = scale_update_cell
+        if scale_update_cell:
+            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
+                                        name="loss_scale")
+
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id,
+                  label_ids,
+                  sens=None):
+
+
+        weights = self.weights
+        init = self.alloc_status()
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id,
+                            label_ids)
+        if sens is None:
+            scaling_sens = self.loss_scale
+        else:
+            scaling_sens = sens
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 label_ids,
+                                                 self.cast(scaling_sens,
+                                                           mstype.float32))
+        clear_before_grad = self.clear_before_grad(init)
+        F.control_depend(loss, init)
+        self.depend_parameter_use(clear_before_grad, scaling_sens)
+        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.reducer_flag:
+            grads = self.grad_reducer(grads)
+        flag = self.get_status(init)
+        flag_sum = self.reduce_sum(init, (0,))
+        if self.is_distributed:
+            flag_reduce = self.allreduce(flag_sum)
+            cond = self.less_equal(self.base, flag_reduce)
+        else:
+            cond = self.less_equal(self.base, flag_sum)
+        F.control_depend(grads, flag)
+        F.control_depend(flag, flag_sum)
+        overflow = cond
+        if sens is None:
+            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        if overflow:
+            succ = False
+        else:
+            succ = self.optimizer(grads)
+        ret = (loss, cond)
+        return F.depend(ret, succ)
+
+class BertCLSModel(nn.Cell):
+    """
+    This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
+    LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final
+    logits, as the result of log_softmax is proportional to that of softmax.
+    """
+    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
+        super(BertCLSModel, self).__init__()
+        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
+        self.cast = P.Cast()
+        self.weight_init = TruncatedNormal(config.initializer_range)
+        self.log_softmax = P.LogSoftmax(axis=-1)
+        self.dtype = config.dtype
+        self.num_labels = num_labels
+        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
+                                has_bias=True).to_float(config.compute_type)
+        self.dropout = nn.Dropout(1 - dropout_prob)
+
+    def construct(self, input_ids, input_mask, token_type_id):
+        _, pooled_output, _ = \
+            self.bert(input_ids, token_type_id, input_mask)
+        cls = self.cast(pooled_output, self.dtype)
+        cls = self.dropout(cls)
+        logits = self.dense_1(cls)
+        logits = self.cast(logits, self.dtype)
+        log_probs = self.log_softmax(logits)
+        return log_probs
+
+
+class BertNERModel(nn.Cell):
+    """
+    This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
+    The returned output represents the final logits, as the result of log_softmax is proportional to that of softmax.
+ """ + def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0, + use_one_hot_embeddings=False): + super(BertNERModel, self).__init__() + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cast = P.Cast() + self.weight_init = TruncatedNormal(config.initializer_range) + self.log_softmax = P.LogSoftmax(axis=-1) + self.dtype = config.dtype + self.num_labels = num_labels + self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.reshape = P.Reshape() + self.shape = (-1, config.hidden_size) + self.use_crf = use_crf + self.origin_shape = (config.batch_size, config.seq_length, self.num_labels) + + def construct(self, input_ids, input_mask, token_type_id): + sequence_output, _, _ = \ + self.bert(input_ids, token_type_id, input_mask) + seq = self.dropout(sequence_output) + seq = self.reshape(seq, self.shape) + logits = self.dense_1(seq) + logits = self.cast(logits, self.dtype) + if self.use_crf: + return_value = self.reshape(logits, self.origin_shape) + else: + return_value = self.log_softmax(logits) + return return_value + +class CrossEntropyCalculation(nn.Cell): + """ + Cross Entropy loss + """ + def __init__(self, is_training=True): + super(CrossEntropyCalculation, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reduce_mean = P.ReduceMean() + self.reshape = P.Reshape() + self.last_idx = (-1,) + self.neg = P.Neg() + self.cast = P.Cast() + self.is_training = is_training + + def construct(self, logits, label_ids, num_labels): + if self.is_training: + label_ids = self.reshape(label_ids, self.last_idx) + one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value) + per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx)) + loss = self.reduce_mean(per_example_loss, self.last_idx) + return_value = self.cast(loss, mstype.float32) + else: + return_value = logits * 1.0 + return return_value + +class BertCLS(nn.Cell): + """ + Train interface for classification finetuning task. + """ + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): + super(BertCLS, self).__init__() + self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) + self.loss = CrossEntropyCalculation(is_training) + self.num_labels = num_labels + def construct(self, input_ids, input_mask, token_type_id, label_ids): + log_probs = self.bert(input_ids, input_mask, token_type_id) + loss = self.loss(log_probs, label_ids, self.num_labels) + return loss + + +class BertNER(nn.Cell): + """ + Train interface for sequence labeling finetuning task. 
+ """ + def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0, + use_one_hot_embeddings=False): + super(BertNER, self).__init__() + self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings) + if use_crf: + if not tag_to_index: + raise Exception("The dict for tag-index mapping should be provided for CRF.") + self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training) + else: + self.loss = CrossEntropyCalculation(is_training) + self.num_labels = num_labels + self.use_crf = use_crf + def construct(self, input_ids, input_mask, token_type_id, label_ids): + logits = self.bert(input_ids, input_mask, token_type_id) + if self.use_crf: + loss = self.loss(logits, label_ids) + else: + loss = self.loss(logits, label_ids, self.num_labels) + return loss diff --git a/tests/ut/python/model/test_bert.py b/tests/ut/python/model/test_bert.py deleted file mode 100644 index 840f594be35..00000000000 --- a/tests/ut/python/model/test_bert.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test bert cell """ -import numpy as np -import pytest - -from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertModel -from ....dataset_mock import MindData - - -def map_bert(record): - target_data = {'input_ids': None, 'input_mask': None, - 'segment_ids': None, 'next_sentence_labels': None, - 'masked_lm_positions': None, 'masked_lm_ids': None, - 'masked_lm_weights': None} - - sample = dt.parse_single_example(record, target_data) - - return sample['input_ids'], sample['input_mask'], sample['segment_ids'], \ - sample['next_sentence_labels'], sample['masked_lm_positions'], \ - sample['masked_lm_ids'], sample['masked_lm_weights'] - - -def test_bert_model(): - # test for config.hidden_size % config.num_attention_heads != 0 - config_error = BertConfig(32, hidden_size=512, num_attention_heads=10) - with pytest.raises(ValueError): - BertModel(config_error, True) - - -def get_dataset(batch_size=1): - dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32) - dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1), - (batch_size, 20), (batch_size, 20), (batch_size, 20)) - - dataset = MindData(size=2, batch_size=batch_size, - np_types=dataset_types, - output_shapes=dataset_shapes, - input_indexs=(0, 1)) - return dataset diff --git a/tests/ut/python/model/test_bert_cell.py b/tests/ut/python/model/test_bert_cell.py deleted file mode 100644 index 175341767af..00000000000 --- a/tests/ut/python/model/test_bert_cell.py +++ /dev/null @@ -1,437 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test bert of graph compile """ -import functools -import numpy as np - -import mindspore.common.dtype as mstype -import mindspore.nn as nn -import mindspore.ops.composite as C -from mindspore.ops import functional as F -from mindspore.common.initializer import TruncatedNormal -from mindspore.common.parameter import ParameterTuple -from mindspore.common.tensor import Tensor -from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput -from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad -from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \ - EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \ - RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \ - BertEncoderCell, BertTransformer, CreateAttentionMaskFromInputMask, BertModel -from mindspore.nn.layer.basic import Norm -from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR -from ....mindspore_test_framework.mindspore_test import mindspore_test -from ....mindspore_test_framework.pipeline.forward.compile_forward import \ - pipeline_for_compile_forward_ge_graph_for_case_by_case_config -from ....mindspore_test_framework.pipeline.gradient.compile_gradient import \ - pipeline_for_compile_grad_ge_graph_for_case_by_case_config -from ....ops_common import convert - - -def bert_trans(): - """bert_trans""" - net = BertTransformer(batch_size=1, - hidden_size=768, - seq_length=128, - num_hidden_layers=1, - num_attention_heads=12, - intermediate_size=768, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32, - return_all_encoders=True) - net.set_train() - return net - - -def set_train(net): - net.set_train() - return net - - -class NetForAdam(nn.Cell): - def __init__(self): - super(NetForAdam, self).__init__() - self.dense = nn.Dense(64, 10) - - def construct(self, x): - x = self.dense(x) - return x - - -class TrainStepWrapForAdam(nn.Cell): - """TrainStepWrapForAdam definition""" - - def __init__(self, network): - super(TrainStepWrapForAdam, self).__init__() - self.network = network - self.weights = ParameterTuple(network.get_parameters()) - self.optimizer = AdamWeightDecay(self.weights) - self.hyper_map = C.HyperMap() - - def construct(self, x, sens): - weights = self.weights - grads = C.grad_by_list_with_sens(self.network, weights)(x, sens) - grads = self.hyper_map(F.partial(clip_grad, 1, 1.0), grads) - return self.optimizer(grads) - - -class TrainStepWrapForAdamDynamicLr(nn.Cell): - """TrainStepWrapForAdamDynamicLr definition""" - - def __init__(self, network): - super(TrainStepWrapForAdamDynamicLr, self).__init__() - self.network = network - self.weights = ParameterTuple(network.get_parameters()) - self.optimizer = AdamWeightDecayDynamicLR(self.weights, 10) - self.sens = Tensor(np.ones(shape=(1, 10)).astype(np.float32)) - - def construct(self, x): - weights = self.weights - grads = 
C.grad_by_list_with_sens(self.network, weights)(x, self.sens) - return self.optimizer(grads) - - -class TempC2Wrap(nn.Cell): - def __init__(self, op, c1=None, c2=None,): - super(TempC2Wrap, self).__init__() - self.op = op - self.c1 = c1 - self.c2 = c2 - self.hyper_map = C.HyperMap() - - def construct(self, x1): - x = self.hyper_map(F.partial(self.op, self.c1, self.c2), x1) - return x - - -test_case_cell_ops = [ - ('Norm_keepdims', { - 'block': Norm(keep_dims=True), - 'desc_inputs': [[1, 3, 4, 4]], - 'desc_bprop': [[1]]}), - ('SaturateCast', { - 'block': SaturateCast(), - 'desc_inputs': [[1, 3, 4, 4]], - 'desc_bprop': [[1, 3, 4, 4]]}), - ('RelaPosMatrixGenerator_0', { - 'block': RelaPosMatrixGenerator(length=128, max_relative_position=16), - 'desc_inputs': [], - 'desc_bprop': [[128, 128]], - 'skip': ['backward']}), - ('RelaPosEmbeddingsGenerator_0', { - 'block': RelaPosEmbeddingsGenerator(length=128, depth=512, - max_relative_position=16, - initializer_range=0.2), - 'desc_inputs': [], - 'desc_bprop': [[16384, 512]], - 'skip': ['backward']}), - ('RelaPosEmbeddingsGenerator_1', { - 'block': RelaPosEmbeddingsGenerator(length=128, depth=512, - max_relative_position=16, - initializer_range=0.2, - use_one_hot_embeddings=False), - 'desc_inputs': [], - 'desc_bprop': [[128, 128, 512]], - 'skip': ['backward']}), - ('RelaPosEmbeddingsGenerator_2', { - 'block': RelaPosEmbeddingsGenerator(length=128, depth=64, - max_relative_position=16, - initializer_range=0.2, - use_one_hot_embeddings=False), - 'desc_inputs': [], - 'desc_bprop': [[128, 128, 64]], - 'skip': ['backward']}), - ('BertAttention_0', { - 'block': BertAttention(batch_size=64, - from_tensor_width=768, - to_tensor_width=768, - from_seq_length=128, - to_seq_length=128, - num_attention_heads=12, - size_per_head=64, - query_act=None, - key_act=None, - value_act=None, - has_attention_mask=True, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - do_return_2d_tensor=True, - use_relative_positions=False, - compute_type=mstype.float32), - 'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]], - 'desc_bprop': [[8192, 768]]}), - ('BertAttention_1', { - 'block': BertAttention(batch_size=64, - from_tensor_width=768, - to_tensor_width=768, - from_seq_length=128, - to_seq_length=128, - num_attention_heads=12, - size_per_head=64, - query_act=None, - key_act=None, - value_act=None, - has_attention_mask=True, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - do_return_2d_tensor=True, - use_relative_positions=True, - compute_type=mstype.float32), - 'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]], - 'desc_bprop': [[8192, 768]]}), - ('BertAttention_2', { - 'block': BertAttention(batch_size=64, - from_tensor_width=768, - to_tensor_width=768, - from_seq_length=128, - to_seq_length=128, - num_attention_heads=12, - size_per_head=64, - query_act=None, - key_act=None, - value_act=None, - has_attention_mask=False, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - do_return_2d_tensor=True, - use_relative_positions=True, - compute_type=mstype.float32), - 'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]], - 'desc_bprop': [[8192, 768]]}), - ('BertAttention_3', { - 'block': BertAttention(batch_size=64, - from_tensor_width=768, - to_tensor_width=768, - from_seq_length=128, - to_seq_length=128, - num_attention_heads=12, - size_per_head=64, - query_act=None, - key_act=None, - value_act=None, - 
has_attention_mask=True, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - do_return_2d_tensor=False, - use_relative_positions=True, - compute_type=mstype.float32), - 'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]], - 'desc_bprop': [[8192, 768]]}), - ('BertOutput', { - 'block': BertOutput(in_channels=768, - out_channels=768, - initializer_range=0.02, - dropout_prob=0.1), - 'desc_inputs': [[8192, 768], [8192, 768]], - 'desc_bprop': [[8192, 768]]}), - ('BertSelfAttention_0', { - 'block': BertSelfAttention(batch_size=64, - seq_length=128, - hidden_size=768, - num_attention_heads=12, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - compute_type=mstype.float32), - 'desc_inputs': [[64, 128, 768], [64, 128, 128]], - 'desc_bprop': [[8192, 768]]}), - ('BertEncoderCell', { - 'block': BertEncoderCell(batch_size=64, - hidden_size=768, - seq_length=128, - num_attention_heads=12, - intermediate_size=768, - attention_probs_dropout_prob=0.02, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32), - 'desc_inputs': [[64, 128, 768], [64, 128, 128]], - 'desc_bprop': [[8192, 768]]}), - ('BertTransformer_0', { - 'block': BertTransformer(batch_size=1, - hidden_size=768, - seq_length=128, - num_hidden_layers=1, - num_attention_heads=12, - intermediate_size=768, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32, - return_all_encoders=True), - 'desc_inputs': [[1, 128, 768], [1, 128, 128]]}), - ('BertTransformer_1', { - 'block': BertTransformer(batch_size=64, - hidden_size=768, - seq_length=128, - num_hidden_layers=2, - num_attention_heads=12, - intermediate_size=768, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - use_relative_positions=True, - hidden_act="gelu", - compute_type=mstype.float32, - return_all_encoders=False), - 'desc_inputs': [[64, 128, 768], [64, 128, 128]]}), - ('EmbeddingLookup', { - 'block': EmbeddingLookup(vocab_size=32000, - embedding_size=768, - embedding_shape=[1, 128, 768], - use_one_hot_embeddings=False, - initializer_range=0.02), - 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32))], - 'desc_bprop': [[1, 128, 768], [1, 128, 768]], - 'num_output': 2}), - ('EmbeddingPostprocessor', { - 'block': EmbeddingPostprocessor(embedding_size=768, - embedding_shape=[1, 128, 768], - use_token_type=True, - token_type_vocab_size=16, - use_one_hot_embeddings=False, - initializer_range=0.02, - max_position_embeddings=512, - dropout_prob=0.1), - 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), [1, 128, 768]], - 'desc_bprop': [[1, 128, 768]]}), - ('CreateAttentionMaskFromInputMask', { - 'block': CreateAttentionMaskFromInputMask(config=BertConfig(batch_size=1)), - 'desc_inputs': [[128]], - 'desc_bprop': [[1, 128, 128]]}), - ('BertOutput_0', { - 'block': BertOutput(in_channels=768, - out_channels=768, - initializer_range=0.02, - dropout_prob=0.1), - 'desc_inputs': [[1, 768], [1, 768]], - 'desc_bprop': [[1, 768]]}), - ('BertTransformer_2', { - 'block': bert_trans(), - 'desc_inputs': [[1, 128, 768], [1, 128, 128]]}), - - ('BertModel', { - 'block': BertModel(config=BertConfig(batch_size=1, - num_hidden_layers=1, - intermediate_size=768, - 
token_type_ids_from_dataset=False), - is_training=True), - 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), - Tensor(np.random.rand(128).astype(np.int32)), [128]], - 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], - 'num_output': 3}), - - ('BertModel_1', { - 'block': BertModel(config=BertConfig(batch_size=1, - num_hidden_layers=1, - intermediate_size=768, - token_type_ids_from_dataset=False), - is_training=False), - 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), - Tensor(np.random.rand(128).astype(np.int32)), [128]], - 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], - 'num_output': 3}), - - ('BertModel_2', { - 'block': BertModel(config=BertConfig(batch_size=1, - num_hidden_layers=1, - intermediate_size=768, - token_type_ids_from_dataset=False, - input_mask_from_dataset=False), - is_training=True), - 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), - Tensor(np.random.rand(128).astype(np.int32)), [128]], - 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], - 'num_output': 3}), - - ('BertPretrainingLoss', { - 'block': BertPretrainingLoss(config=BertConfig(batch_size=1)), - 'desc_inputs': [[32000], [20, 2], Tensor(np.array([1]).astype(np.int32)), - [20], Tensor(np.array([20]).astype(np.int32))], - 'desc_bprop': [[1]], - 'num_output': 1}), - ('Dense_1', { - 'block': nn.Dense(in_channels=768, - out_channels=3072, - activation='gelu', - weight_init=TruncatedNormal(0.02)), - 'desc_inputs': [[3, 768]], - 'desc_bprop': [[3, 3072]]}), - ('Dense_2', { - 'block': set_train(nn.Dense(in_channels=768, - out_channels=3072, - activation='gelu', - weight_init=TruncatedNormal(0.02),)), - 'desc_inputs': [[3, 768]], - 'desc_bprop': [[3, 3072]]}), - ('GetNextSentenceOutput', { - 'block': GetNextSentenceOutput(BertConfig(batch_size=1)), - 'desc_inputs': [[128, 768]], - 'desc_bprop': [[128, 2]]}), - ('Adam_1', { - 'block': set_train(TrainStepWrapForAdam(NetForAdam())), - 'desc_inputs': [[1, 64], [1, 10]], - 'skip': ['backward']}), - ('Adam_2', { - 'block': set_train(TrainStepWrapForAdam(GetNextSentenceOutput(BertConfig(batch_size=1)))), - 'desc_inputs': [[128, 768], [128, 2]], - 'skip': ['backward']}), - ('AdamWeightDecayDynamicLR', { - 'block': set_train(TrainStepWrapForAdamDynamicLr(NetForAdam())), - 'desc_inputs': [[1, 64]], - 'skip': ['backward']}), - ('ClipGradients', { - 'block': TempC2Wrap(clip_grad, 1, 1.0), - 'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])], - 'skip': ['backward', 'exec']}), -] - -test_case = functools.reduce(lambda x, y: x + y, [test_case_cell_ops]) -# use -k to select certain testcast -# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm - - -test_exec_case = filter(lambda x: 'skip' not in x[1] or - 'exec' not in x[1]['skip'], test_case) -test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or - 'backward' not in x[1]['skip'] and 'backward_exec' - not in x[1]['skip'], test_case) -test_check_gradient_case = filter(lambda x: 'skip' not in x[1] or - 'backward' not in x[1]['skip'] and 'backward_exec' - not in x[1]['skip'], test_case) - - -@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) -def test_exec(): - return test_exec_case - - -@mindspore_test(pipeline_for_compile_grad_ge_graph_for_case_by_case_config) -def test_backward_exec(): - return test_backward_exec_case diff --git a/tests/ut/python/nn/test_embedding.py b/tests/ut/python/nn/test_embedding.py deleted file mode 100644 index e0f2f78f572..00000000000 --- 
a/tests/ut/python/nn/test_embedding.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test_embedding """ -import numpy as np - -from mindspore import Tensor -from mindspore import dtype as mstype -from mindspore.model_zoo.Bert_NEZHA import EmbeddingLookup, EmbeddingPostprocessor -from ..ut_filter import non_graph_engine - - -@non_graph_engine -def test_check_embedding_lookup_1(): - m = EmbeddingLookup(vocab_size=32000, - embedding_size=768, - embedding_shape=[1, 128, 768], - use_one_hot_embeddings=False) - m(Tensor(np.ones([128]), mstype.int32)) - - -@non_graph_engine -def test_check_embedding_lookup_2(): - m = EmbeddingLookup(vocab_size=32000, - embedding_size=768, - embedding_shape=[1, 128, 768], - use_one_hot_embeddings=True) - m(Tensor(np.ones([128]), mstype.int32)) - - -@non_graph_engine -def test_check_embedding_lookup_3(): - m = EmbeddingLookup(vocab_size=32000, - embedding_size=768, - embedding_shape=[1, 128, 768], - use_one_hot_embeddings=True, - initializer_range=0.01) - m(Tensor(np.ones([128]), mstype.int32)) - - -@non_graph_engine -def test_embedding_post_1(): - m = EmbeddingPostprocessor(embedding_size=768, - embedding_shape=[1, 128, 768], - use_token_type=True) - m(Tensor(np.ones([128]), mstype.int32), Tensor(np.ones([1, 128, 768]), mstype.float32)) - - -@non_graph_engine -def test_embedding_post_2(): - m = EmbeddingPostprocessor(embedding_size=768, - embedding_shape=[1, 128, 768], - use_token_type=True, - initializer_range=0.3) - m(Tensor(np.ones([128]), mstype.int32), Tensor(np.ones([1, 128, 768]), mstype.float32))
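For readers tracing the patch, the densest logic above is the overflow handling in `BertFinetuneCell.construct`, which expresses dynamic loss scaling through NPU float-status ops and `ControlDepend`. The following is a minimal plain-Python sketch of that control flow only, not part of the patch; `forward_fn`, `backward_fn`, `apply_fn`, and `update_loss_scale` are placeholder callables standing in for the network, its gradient function, the optimizer, and the loss-scale manager.

```python
import math

def finetune_step(forward_fn, backward_fn, apply_fn, batch, loss_scale, update_loss_scale=None):
    """One fine-tune step with dynamic loss scaling (illustrative sketch)."""
    loss = forward_fn(*batch)
    # Gradients are taken against a scaled sensitivity, then un-scaled and clipped.
    grads = [g / loss_scale for g in backward_fn(*batch, sens=loss_scale)]
    grads = [max(-1.0, min(1.0, g)) for g in grads]              # clip to [-1, 1]
    # Overflow check stands in for the NPU float-status ops in the cell.
    overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
    if update_loss_scale is not None:
        overflow = update_loss_scale(loss_scale, overflow)       # scale manager may adjust the scale
    if not overflow:
        apply_fn(grads)                                          # parameter update is skipped on overflow
    return loss, overflow
```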