update bert scripts according to rules of modelzoo

This commit is contained in:
chenhaozhe 2020-05-21 17:22:58 +08:00
parent 02f33a17b5
commit b6aceddeab
35 changed files with 2737 additions and 620 deletions

View File

@ -308,7 +308,7 @@ def get_bprop_softmax(self):
axis = self.axis
def bprop(x, out, dout):
dx = mul(sub(dout, sum_func(mul(dout, out), axis)), out)
dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
return (dx,)
return bprop

View File

@ -16,12 +16,12 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.
``` bash
sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
```
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.
``` bash
sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
```
### Fine-Tuning

View File

@ -19,8 +19,6 @@ Bert evaluation script.
import os
import numpy as np
from evaluation_config import cfg, bert_net_cfg
from utils import BertNER, BertCLS
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.common.tensor import Tensor
@ -28,9 +26,11 @@ import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from CRF import postprocess
from cluener_evaluation import submit
from finetune_config import tag_to_index
from src.evaluation_config import cfg, bert_net_cfg
from src.utils import BertNER, BertCLS
from src.CRF import postprocess
from src.cluener_evaluation import submit
from src.finetune_config import tag_to_index
class Accuracy():
'''

View File

@ -18,8 +18,8 @@ Bert finetune script.
'''
import os
from utils import BertFinetuneCell, BertCLS, BertNER
from finetune_config import cfg, bert_net_cfg, tag_to_index
from src.utils import BertFinetuneCell, BertCLS, BertNER
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context

View File

@ -26,10 +26,10 @@ from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
from dataset import create_bert_dataset
from config import cfg, bert_net_cfg
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
_current_dir = os.path.dirname(os.path.realpath(__file__))
class LossCallBack(Callback):
@ -48,10 +48,8 @@ class LossCallBack(Callback):
self._per_print_times = per_print_times
def step_end(self, run_context):
cb_params = run_context.original_args()
with open("./loss.log", "a+") as f:
f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)))
f.write('\n')
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs)))
def run_pretrain():
"""pre-train bert_clue"""
@ -81,6 +79,11 @@ def run_pretrain():
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
from mindspore.parallel._auto_parallel_context import auto_parallel_context
if bert_net_cfg.num_hidden_layers == 12:
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
elif bert_net_cfg.num_hidden_layers == 24:
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
D.init()
rank = args_opt.device_id % device_num
else:

View File

@ -16,8 +16,8 @@
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json"
echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: bash run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
@ -49,6 +49,10 @@ do
cp *.py ./LOG$i
cd ./LOG$i || exit
echo "start training for rank $i, device $DEVICE_ID"
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
env > env.log
taskset -c $cmdopt python ../run_pretrain.py \
--distribute="true" \
@ -59,7 +63,7 @@ do
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--data_sink_steps=100 \
--checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \

View File

@ -16,8 +16,8 @@
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
echo "=============================================================================================================="
DEVICE_ID=$1
@ -25,6 +25,10 @@ EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python run_pretrain.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
@ -33,7 +37,7 @@ python run_pretrain.py \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--data_sink_steps=100 \
--checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \

View File

@ -357,10 +357,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
@ -411,10 +411,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:

View File

@ -25,6 +25,7 @@ from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .fused_layer_norm import FusedLayerNorm
class BertConfig:
@ -77,7 +78,8 @@ class BertConfig:
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float32):
compute_type=mstype.float32,
enable_fused_layernorm=False):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
@ -96,6 +98,7 @@ class BertConfig:
self.use_relative_positions = use_relative_positions
self.dtype = dtype
self.compute_type = compute_type
self.enable_fused_layernorm = enable_fused_layernorm
class EmbeddingLookup(nn.Cell):
@ -240,13 +243,19 @@ class BertOutput(nn.Cell):
out_channels,
initializer_range=0.02,
dropout_prob=0.1,
compute_type=mstype.float32):
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertOutput, self).__init__()
self.dense = nn.Dense(in_channels, out_channels,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.dropout_prob = dropout_prob
self.add = P.TensorAdd()
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
if compute_type == mstype.float16:
self.layernorm = FusedLayerNorm((out_channels,),
use_batch_norm=enable_fused_layernorm).to_float(compute_type)
else:
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):
@ -481,12 +490,13 @@ class BertAttention(nn.Cell):
self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
self.cast_compute_type = SaturateCast(dst_type=compute_type)
self._generate_relative_positions_embeddings = \
RelaPosEmbeddingsGenerator(length=to_seq_length,
depth=size_per_head,
max_relative_position=16,
initializer_range=initializer_range,
use_one_hot_embeddings=use_one_hot_embeddings)
if self.use_relative_positions:
self._generate_relative_positions_embeddings = \
RelaPosEmbeddingsGenerator(length=to_seq_length,
depth=size_per_head,
max_relative_position=16,
initializer_range=initializer_range,
use_one_hot_embeddings=use_one_hot_embeddings)
def construct(self, from_tensor, to_tensor, attention_mask):
# reshape 2d/3d input tensors to 2d
@ -529,7 +539,7 @@ class BertAttention(nn.Cell):
self.trans_shape_position)
attention_scores = attention_scores + key_position_scores_r_t
attention_scores = self.multiply(attention_scores, self.scores_mul)
attention_scores = self.multiply(self.scores_mul, attention_scores)
if self.has_attention_mask:
attention_mask = self.expand_dims(attention_mask, 1)
@ -606,7 +616,8 @@ class BertSelfAttention(nn.Cell):
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
compute_type=mstype.float32):
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
@ -634,7 +645,8 @@ class BertSelfAttention(nn.Cell):
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
@ -676,7 +688,8 @@ class BertEncoderCell(nn.Cell):
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32):
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertEncoderCell, self).__init__()
self.attention = BertSelfAttention(
batch_size=batch_size,
@ -688,7 +701,8 @@ class BertEncoderCell(nn.Cell):
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
compute_type=compute_type)
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
self.intermediate = nn.Dense(in_channels=hidden_size,
out_channels=intermediate_size,
activation=hidden_act,
@ -697,7 +711,8 @@ class BertEncoderCell(nn.Cell):
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
def construct(self, hidden_states, attention_mask):
# self-attention
@ -744,7 +759,8 @@ class BertTransformer(nn.Cell):
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=False):
return_all_encoders=False,
enable_fused_layernorm=False):
super(BertTransformer, self).__init__()
self.return_all_encoders = return_all_encoders
@ -761,7 +777,8 @@ class BertTransformer(nn.Cell):
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
hidden_act=hidden_act,
compute_type=compute_type)
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
layers.append(layer)
self.layers = nn.CellList(layers)
@ -888,7 +905,8 @@ class BertModel(nn.Cell):
use_relative_positions=config.use_relative_positions,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
return_all_encoders=True)
return_all_encoders=True,
enable_fused_layernorm=config.enable_fused_layernorm)
self.cast = P.Cast()
self.dtype = config.dtype

View File

@ -17,12 +17,12 @@
import json
import numpy as np
from evaluation_config import cfg
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from CRF import postprocess
import tokenization
from sample_process import label_generation, process_one_example_p
from .evaluation_config import cfg
from .CRF import postprocess
vocab_file = "./vocab.txt"
tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file)

View File

@ -17,16 +17,16 @@ network config setting, will be used in dataset.py, run_pretrain.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
from .bert_model import BertConfig
cfg = edict({
'bert_network': 'base',
'loss_scale_value': 2**32,
'loss_scale_value': 65536,
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({
'learning_rate': 3e-5,
'end_learning_rate': 1e-7,
'end_learning_rate': 1e-10,
'power': 5.0,
'weight_decay': 1e-5,
'eps': 1e-6,
@ -34,7 +34,7 @@ cfg = edict({
}),
'Lamb': edict({
'start_learning_rate': 3e-5,
'end_learning_rate': 1e-7,
'end_learning_rate': 1e-10,
'power': 10.0,
'warmup_steps': 10000,
'weight_decay': 0.01,
@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
bert_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=21128,
vocab_size=21136,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
@ -71,13 +71,13 @@ if cfg.bert_network == 'base':
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
compute_type=mstype.float16
)
if cfg.bert_network == 'nezha':
bert_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=21128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
@ -92,5 +92,27 @@ if cfg.bert_network == 'nezha':
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
compute_type=mstype.float16
)
if cfg.bert_network == 'large':
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=512,
vocab_size=30528,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=True
)

View File

@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import log as logger
from config import bert_net_cfg
from .config import bert_net_cfg
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
@ -31,8 +31,9 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
files = os.listdir(data_dir)
data_files = []
for file_name in files:
data_files.append(os.path.join(data_dir, file_name))
ds = de.TFRecordDataset(data_files, schema_dir,
if "tfrecord" in file_name:
data_files.append(os.path.join(data_dir, file_name))
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,

View File

@ -19,7 +19,7 @@ config settings, will be used in finetune.py
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
from .bert_model import BertConfig
cfg = edict({
'task': 'NER',

View File

@ -19,7 +19,7 @@ config settings, will be used in finetune.py
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
from .bert_model import BertConfig
cfg = edict({
'task': 'NER',

View File

@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell
import numpy as np
__all__ = ['FusedLayerNorm']
@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
print("input_shape: ", x_shape)
norm_shape = x_shape[begin_norm_axis:]
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
print("output_shape: ", output_shape)
return output_shape
class FusedLayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
Layer normalization is widely used in recurrent neural networks. It applies
normalization over a mini-batch of inputs for each single training case as described
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
normalization, layer normalization performs exactly the same computation at training and
testing times. It can be described using the following formula. It is applied across all channels
and pixel but only one batch size.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis
`begin_norm_axis ... R - 1`.
begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
use_batch_nrom (bool): Whether use batchnorm to preocess.
Inputs:
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
Outputs:
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape()[1:]
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x)
"""
def __init__(self,
normalized_shape,
begin_norm_axis=-1,
begin_params_axis=-1,
gamma_init='ones',
beta_init='zeros',
use_batch_norm=False):
super(FusedLayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
.format(normalized_shape, type(normalized_shape)))
self.normalized_shape = normalized_shape
self.begin_norm_axis = begin_norm_axis
self.begin_params_axis = begin_params_axis
self.gamma = Parameter(initializer(
gamma_init, normalized_shape), name="gamma")
self.beta = Parameter(initializer(
beta_init, normalized_shape), name="beta")
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
self.use_batch_norm = use_batch_norm
def construct(self, input_x):
if self.use_batch_norm and self.training:
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
shape_x = F.shape(input_x)
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
input_x = F.reshape(input_x, norm_shape)
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
output = F.reshape(output, shape_x)
y = output * self.gamma + self.beta
else:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
return y
def extend_repr(self):
"""Display instance object as string."""
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
return s

View File

@ -30,8 +30,8 @@ from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
from CRF import CRF
from .bert_for_pre_training import clip_grad
from .CRF import CRF
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

View File

@ -25,7 +25,8 @@ import mindspore.dataset.transforms.c_transforms as C
from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
from src.bert_model import BertConfig
from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
from mindspore.nn.optim import Lamb
from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
@ -77,7 +78,8 @@ def get_config(version='base', batch_size=1):
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16)
compute_type=mstype.float16,
enable_fused_layernorm=False)
else:
bert_config = BertConfig(batch_size=batch_size)
return bert_config

View File

@ -0,0 +1,177 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
CRF script.
'''
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
class CRF(nn.Cell):
'''
Conditional Random Field
Args:
tag_to_index: The dict for tag to index mapping with extra "<START>" and "<STOP>"sign.
batch_size: Batch size, i.e., the length of the first dimension.
seq_length: Sequence length, i.e., the length of the second dimention.
is_training: Specifies whether to use training mode.
Returns:
Training mode: Tensor, total loss.
Evaluation mode: Tuple, the index for each step with the highest score; Tuple, the index for the last
step with the highest score.
'''
def __init__(self, tag_to_index, batch_size=1, seq_length=128, is_training=True):
super(CRF, self).__init__()
self.target_size = len(tag_to_index)
self.is_training = is_training
self.tag_to_index = tag_to_index
self.batch_size = batch_size
self.seq_length = seq_length
self.START_TAG = "<START>"
self.STOP_TAG = "<STOP>"
self.START_VALUE = Tensor(self.target_size-2, dtype=mstype.int32)
self.STOP_VALUE = Tensor(self.target_size-1, dtype=mstype.int32)
transitions = np.random.normal(size=(self.target_size, self.target_size)).astype(np.float32)
transitions[tag_to_index[self.START_TAG], :] = -10000
transitions[:, tag_to_index[self.STOP_TAG]] = -10000
self.transitions = Parameter(Tensor(transitions), name="transition_matrix")
self.cat = P.Concat(axis=-1)
self.argmax = P.ArgMaxWithValue(axis=-1)
self.log = P.Log()
self.exp = P.Exp()
self.sum = P.ReduceSum()
self.tile = P.Tile()
self.reduce_sum = P.ReduceSum(keep_dims=True)
self.reshape = P.Reshape()
self.expand = P.ExpandDims()
self.mean = P.ReduceMean()
init_alphas = np.ones(shape=(self.batch_size, self.target_size)) * -10000.0
init_alphas[:, self.tag_to_index[self.START_TAG]] = 0.
self.init_alphas = Tensor(init_alphas, dtype=mstype.float32)
self.cast = P.Cast()
self.reduce_max = P.ReduceMax(keep_dims=True)
self.on_value = Tensor(1.0, dtype=mstype.float32)
self.off_value = Tensor(0.0, dtype=mstype.float32)
self.onehot = P.OneHot()
def log_sum_exp(self, logits):
'''
Compute the log_sum_exp score for normalization factor.
'''
max_score = self.reduce_max(logits, -1) #16 5 5
score = self.log(self.reduce_sum(self.exp(logits - max_score), -1))
score = max_score + score
return score
def _realpath_score(self, features, label):
'''
Compute the emission and transition score for the real path.
'''
label = label * 1
concat_A = self.tile(self.reshape(self.START_VALUE, (1,)), (self.batch_size,))
concat_A = self.reshape(concat_A, (self.batch_size, 1))
labels = self.cat((concat_A, label))
onehot_label = self.onehot(label, self.target_size, self.on_value, self.off_value)
emits = features * onehot_label
labels = self.onehot(labels, self.target_size, self.on_value, self.off_value)
label1 = labels[:, 1:, :]
label2 = labels[:, :self.seq_length, :]
label1 = self.expand(label1, 3)
label2 = self.expand(label2, 2)
label_trans = label1 * label2
transitions = self.expand(self.expand(self.transitions, 0), 0)
trans = transitions * label_trans
score = self.sum(emits, (1, 2)) + self.sum(trans, (1, 2, 3))
stop_value_index = labels[:, (self.seq_length-1):self.seq_length, :]
stop_value = self.transitions[(self.target_size-1):self.target_size, :]
stop_score = stop_value * self.reshape(stop_value_index, (self.batch_size, self.target_size))
score = score + self.sum(stop_score, 1)
score = self.reshape(score, (self.batch_size, -1))
return score
def _normalization_factor(self, features):
'''
Compute the total score for all the paths.
'''
forward_var = self.init_alphas
forward_var = self.expand(forward_var, 1)
for idx in range(self.seq_length):
feat = features[:, idx:(idx+1), :]
emit_score = self.reshape(feat, (self.batch_size, self.target_size, 1))
next_tag_var = emit_score + self.transitions + forward_var
forward_var = self.log_sum_exp(next_tag_var)
forward_var = self.reshape(forward_var, (self.batch_size, 1, self.target_size))
terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
alpha = self.log_sum_exp(terminal_var)
alpha = self.reshape(alpha, (self.batch_size, -1))
return alpha
def _decoder(self, features):
'''
Viterbi decode for evaluation.
'''
backpointers = ()
forward_var = self.init_alphas
for idx in range(self.seq_length):
feat = features[:, idx:(idx+1), :]
feat = self.reshape(feat, (self.batch_size, self.target_size))
bptrs_t = ()
next_tag_var = self.expand(forward_var, 1) + self.transitions
best_tag_id, best_tag_value = self.argmax(next_tag_var)
bptrs_t += (best_tag_id,)
forward_var = best_tag_value + feat
backpointers += (bptrs_t,)
terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
best_tag_id, _ = self.argmax(terminal_var)
return backpointers, best_tag_id
def construct(self, features, label):
if self.is_training:
forward_score = self._normalization_factor(features)
gold_score = self._realpath_score(features, label)
return_value = self.mean(forward_score - gold_score)
else:
path_list, tag = self._decoder(features)
return_value = path_list, tag
return return_value
def postprocess(backpointers, best_tag_id):
'''
Do postprocess
'''
best_tag_id = best_tag_id.asnumpy()
batch_size = len(best_tag_id)
best_path = []
for i in range(batch_size):
best_path.append([])
best_local_id = best_tag_id[i]
best_path[-1].append(best_local_id)
for bptrs_t in reversed(backpointers):
bptrs_t = bptrs_t[0].asnumpy()
local_idx = bptrs_t[i]
best_local_id = local_idx[best_local_id]
best_path[-1].append(best_local_id)
# Pop off the start tag (we dont want to return that to the caller)
best_path[-1].pop()
best_path[-1].reverse()
return best_path

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert Init."""
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
SaturateCast, CreateAttentionMaskFromInputMask
__all__ = [
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell",
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
]

View File

@ -0,0 +1,434 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert for pretraining."""
import numpy as np
import mindspore.nn as nn
from mindspore.common.initializer import initializer, TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from .bert_model import BertModel
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
_nn_clip_by_norm = nn.ClipByNorm()
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type != 0 and clip_type != 1:
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
class GetMaskedLMOutput(nn.Cell):
"""
Get masked lm output.
Args:
config (BertConfig): The config of BertModel.
Returns:
Tensor, masked lm output.
"""
def __init__(self, config):
super(GetMaskedLMOutput, self).__init__()
self.width = config.hidden_size
self.reshape = P.Reshape()
self.gather = P.GatherV2()
weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(self.width,
config.hidden_size,
weight_init=weight_init,
activation=config.hidden_act).to_float(config.compute_type)
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
self.output_bias = Parameter(
initializer(
'zero',
config.vocab_size),
name='output_bias')
self.matmul = P.MatMul(transpose_b=True)
self.log_softmax = nn.LogSoftmax(axis=-1)
self.shape_flat_offsets = (-1, 1)
self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32))
self.last_idx = (-1,)
self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width)
self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32))
self.cast = P.Cast()
self.compute_type = config.compute_type
self.dtype = config.dtype
def construct(self,
input_tensor,
output_weights,
positions):
flat_offsets = self.reshape(
self.rng * self.seq_length_tensor, self.shape_flat_offsets)
flat_position = self.reshape(positions + flat_offsets, self.last_idx)
flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
input_tensor = self.cast(input_tensor, self.compute_type)
output_weights = self.cast(output_weights, self.compute_type)
input_tensor = self.dense(input_tensor)
input_tensor = self.layernorm(input_tensor)
logits = self.matmul(input_tensor, output_weights)
logits = self.cast(logits, self.dtype)
logits = logits + self.output_bias
log_probs = self.log_softmax(logits)
return log_probs
class GetNextSentenceOutput(nn.Cell):
"""
Get next sentence output.
Args:
config (BertConfig): The config of Bert.
Returns:
Tensor, next sentence output.
"""
def __init__(self, config):
super(GetNextSentenceOutput, self).__init__()
self.log_softmax = P.LogSoftmax()
self.weight_init = TruncatedNormal(config.initializer_range)
self.dense = nn.Dense(config.hidden_size, 2,
weight_init=self.weight_init, has_bias=True).to_float(config.compute_type)
self.dtype = config.dtype
self.cast = P.Cast()
def construct(self, input_tensor):
logits = self.dense(input_tensor)
logits = self.cast(logits, self.dtype)
log_prob = self.log_softmax(logits)
return log_prob
class BertPreTraining(nn.Cell):
"""
Bert pretraining network.
Args:
config (BertConfig): The config of BertModel.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
Returns:
Tensor, prediction_scores, seq_relationship_score.
"""
def __init__(self, config, is_training, use_one_hot_embeddings):
super(BertPreTraining, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cls1 = GetMaskedLMOutput(config)
self.cls2 = GetNextSentenceOutput(config)
def construct(self, input_ids, input_mask, token_type_id,
masked_lm_positions):
sequence_output, pooled_output, embedding_table = \
self.bert(input_ids, token_type_id, input_mask)
prediction_scores = self.cls1(sequence_output,
embedding_table,
masked_lm_positions)
seq_relationship_score = self.cls2(pooled_output)
return prediction_scores, seq_relationship_score
class BertPretrainingLoss(nn.Cell):
"""
Provide bert pre-training loss.
Args:
config (BertConfig): The config of BertModel.
Returns:
Tensor, total loss.
"""
def __init__(self, config):
super(BertPretrainingLoss, self).__init__()
self.vocab_size = config.vocab_size
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids,
masked_lm_weights, next_sentence_labels):
"""Defines the computation performed."""
label_ids = self.reshape(masked_lm_ids, self.last_idx)
label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32)
one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
numerator = self.reduce_sum(label_weights * per_example_loss, ())
denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
masked_lm_loss = numerator / denominator
# next_sentence_loss
labels = self.reshape(next_sentence_labels, self.last_idx)
one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(
one_hot_labels * seq_relationship_score, self.last_idx))
next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx)
# total_loss
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
class BertNetworkWithLoss(nn.Cell):
"""
Provide bert pre-training loss through network.
Args:
config (BertConfig): The config of BertModel.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
Returns:
Tensor, the loss of the network.
"""
def __init__(self, config, is_training, use_one_hot_embeddings=False):
super(BertNetworkWithLoss, self).__init__()
self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings)
self.loss = BertPretrainingLoss(config)
self.cast = P.Cast()
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights):
prediction_scores, seq_relationship_score = \
self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
total_loss = self.loss(prediction_scores, seq_relationship_score,
masked_lm_ids, masked_lm_weights, next_sentence_labels)
return self.cast(total_loss, mstype.float32)
class BertTrainOneStepCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
"""
def __init__(self, network, optimizer, sens=1.0):
super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.cast = P.Cast()
self.hyper_map = C.HyperMap()
def set_sens(self, value):
self.sens = value
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(F.tuple_to_array((self.sens,)),
mstype.float32))
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
succ = self.optimizer(grads)
return F.depend(loss, succ)
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class BertTrainOneStepWithLossScaleCell(nn.Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad',
get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.add_flags(has_effect=True)
def construct(self,
input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
next_sentence_labels,
masked_lm_positions,
masked_lm_ids,
masked_lm_weights,
self.cast(scaling_sens,
mstype.float32))
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)

View File

@ -0,0 +1,949 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""
import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .fused_layer_norm import FusedLayerNorm
class BertConfig:
"""
Configuration for `BertModel`.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence. Default: 128.
vocab_size (int): The shape of each embedding vector. Default: 32000.
hidden_size (int): Size of the bert encoder layers. Default: 768.
num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder
cell. Default: 12.
num_attention_heads (int): Number of attention heads in the BertTransformer
encoder cell. Default: 12.
intermediate_size (int): Size of intermediate layer in the BertTransformer
encoder cell. Default: 3072.
hidden_act (str): Activation function used in the BertTransformer encoder
cell. Default: "gelu".
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 512.
type_vocab_size (int): Size of token type vocab. Default: 16.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
dataset. Default: True.
token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded
from dataset. Default: True.
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
"""
def __init__(self,
batch_size,
seq_length=128,
vocab_size=32000,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float32,
enable_fused_layernorm=False):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.input_mask_from_dataset = input_mask_from_dataset
self.token_type_ids_from_dataset = token_type_ids_from_dataset
self.use_relative_positions = use_relative_positions
self.dtype = dtype
self.compute_type = compute_type
self.enable_fused_layernorm = enable_fused_layernorm
class EmbeddingLookup(nn.Cell):
"""
A embeddings lookup table with a fixed dictionary and size.
Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
each embedding vector.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
"""
def __init__(self,
vocab_size,
embedding_size,
embedding_shape,
use_one_hot_embeddings=False,
initializer_range=0.02):
super(EmbeddingLookup, self).__init__()
self.vocab_size = vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[vocab_size, embedding_size]),
name='embedding_table')
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.GatherV2()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
def construct(self, input_ids):
extended_ids = self.expand(input_ids, -1)
flat_ids = self.reshape(extended_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(
one_hot_ids, self.embedding_table)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, self.shape)
return output, self.embedding_table
class EmbeddingPostprocessor(nn.Cell):
"""
Postprocessors apply positional and token type embeddings to word embeddings.
Args:
embedding_size (int): The size of each embedding vector.
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
each embedding vector.
use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
token_type_vocab_size (int): Size of token type vocab. Default: 16.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 512.
dropout_prob (float): The dropout probability. Default: 0.1.
"""
def __init__(self,
embedding_size,
embedding_shape,
use_relative_positions=False,
use_token_type=False,
token_type_vocab_size=16,
use_one_hot_embeddings=False,
initializer_range=0.02,
max_position_embeddings=512,
dropout_prob=0.1):
super(EmbeddingPostprocessor, self).__init__()
self.use_token_type = use_token_type
self.token_type_vocab_size = token_type_vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.max_position_embeddings = max_position_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[token_type_vocab_size,
embedding_size]),
name='embedding_table')
self.shape_flat = (-1,)
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.1, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
self.layernorm = nn.LayerNorm((embedding_size,))
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.GatherV2()
self.use_relative_positions = use_relative_positions
self.slice = P.StridedSlice()
self.full_position_embeddings = Parameter(initializer
(TruncatedNormal(initializer_range),
[max_position_embeddings,
embedding_size]),
name='full_position_embeddings')
def construct(self, token_type_ids, word_embeddings):
output = word_embeddings
if self.use_token_type:
flat_ids = self.reshape(token_type_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids,
self.token_type_vocab_size, self.on_value, self.off_value)
token_type_embeddings = self.array_mul(one_hot_ids,
self.embedding_table)
else:
token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
output += token_type_embeddings
if not self.use_relative_positions:
_, seq, width = self.shape
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
output += position_embeddings
output = self.layernorm(output)
output = self.dropout(output)
return output
class BertOutput(nn.Cell):
"""
Apply a linear computation to hidden status and a residual computation to input.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
dropout_prob (float): The dropout probability. Default: 0.1.
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
"""
def __init__(self,
in_channels,
out_channels,
initializer_range=0.02,
dropout_prob=0.1,
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertOutput, self).__init__()
self.dense = nn.Dense(in_channels, out_channels,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.dropout_prob = dropout_prob
self.add = P.TensorAdd()
if compute_type == mstype.float16:
self.layernorm = FusedLayerNorm((out_channels,),
use_batch_norm=enable_fused_layernorm).to_float(compute_type)
else:
self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
self.cast = P.Cast()
def construct(self, hidden_status, input_tensor):
output = self.dense(hidden_status)
output = self.dropout(output)
output = self.add(output, input_tensor)
output = self.layernorm(output)
return output
class RelaPosMatrixGenerator(nn.Cell):
"""
Generates matrix of relative positions between inputs.
Args:
length (int): Length of one dim for the matrix to be generated.
max_relative_position (int): Max value of relative position.
"""
def __init__(self, length, max_relative_position):
super(RelaPosMatrixGenerator, self).__init__()
self._length = length
self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
self.range_length = -length + 1
self.tile = P.Tile()
self.range_mat = P.Reshape()
self.sub = P.Sub()
self.expanddims = P.ExpandDims()
self.cast = P.Cast()
def construct(self):
range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
tile_row_out = self.tile(range_vec_row_out, (self._length,))
tile_col_out = self.tile(range_vec_col_out, (1, self._length))
range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
distance_mat = self.sub(range_mat_out, transpose_out)
distance_mat_clipped = C.clip_by_value(distance_mat,
self._min_relative_position,
self._max_relative_position)
# Shift values to be >=0. Each integer still uniquely identifies a
# relative position difference.
final_mat = distance_mat_clipped + self._max_relative_position
return final_mat
class RelaPosEmbeddingsGenerator(nn.Cell):
"""
Generates tensor of size [length, length, depth].
Args:
length (int): Length of one dim for the matrix to be generated.
depth (int): Size of each attention head.
max_relative_position (int): Maxmum value of relative position.
initializer_range (float): Initialization value of TruncatedNormal.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
length,
depth,
max_relative_position,
initializer_range,
use_one_hot_embeddings=False):
super(RelaPosEmbeddingsGenerator, self).__init__()
self.depth = depth
self.vocab_size = max_relative_position * 2 + 1
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embeddings_table = Parameter(
initializer(TruncatedNormal(initializer_range),
[self.vocab_size, self.depth]),
name='embeddings_for_position')
self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
max_relative_position=max_relative_position)
self.reshape = P.Reshape()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.shape = P.Shape()
self.gather = P.GatherV2() # index_select
self.matmul = P.BatchMatMul()
def construct(self):
relative_positions_matrix_out = self.relative_positions_matrix()
# Generate embedding for each relative position of dimension depth.
if self.use_one_hot_embeddings:
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
one_hot_relative_positions_matrix = self.one_hot(
flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
embeddings = self.reshape(embeddings, my_shape)
else:
embeddings = self.gather(self.embeddings_table,
relative_positions_matrix_out, 0)
return embeddings
class SaturateCast(nn.Cell):
"""
Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
the danger that the value will overflow or underflow.
Args:
src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
"""
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
super(SaturateCast, self).__init__()
np_type = mstype.dtype_to_nptype(dst_type)
min_type = np.finfo(np_type).min
max_type = np.finfo(np_type).max
self.tensor_min_type = Tensor([min_type], dtype=src_type)
self.tensor_max_type = Tensor([max_type], dtype=src_type)
self.min_op = P.Minimum()
self.max_op = P.Maximum()
self.cast = P.Cast()
self.dst_type = dst_type
def construct(self, x):
out = self.max_op(x, self.tensor_min_type)
out = self.min_op(out, self.tensor_max_type)
return self.cast(out, self.dst_type)
class BertAttention(nn.Cell):
"""
Apply multi-headed attention from "from_tensor" to "to_tensor".
Args:
batch_size (int): Batch size of input datasets.
from_tensor_width (int): Size of last dim of from_tensor.
to_tensor_width (int): Size of last dim of to_tensor.
from_seq_length (int): Length of from_tensor sequence.
to_seq_length (int): Length of to_tensor sequence.
num_attention_heads (int): Number of attention heads. Default: 1.
size_per_head (int): Size of each attention head. Default: 512.
query_act (str): Activation function for the query transform. Default: None.
key_act (str): Activation function for the key transform. Default: None.
value_act (str): Activation function for the value transform. Default: None.
has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.0.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
tensor. Default: False.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
from_tensor_width,
to_tensor_width,
from_seq_length,
to_seq_length,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
has_attention_mask=False,
attention_probs_dropout_prob=0.0,
use_one_hot_embeddings=False,
initializer_range=0.02,
do_return_2d_tensor=False,
use_relative_positions=False,
compute_type=mstype.float32):
super(BertAttention, self).__init__()
self.batch_size = batch_size
self.from_seq_length = from_seq_length
self.to_seq_length = to_seq_length
self.num_attention_heads = num_attention_heads
self.size_per_head = size_per_head
self.has_attention_mask = has_attention_mask
self.use_relative_positions = use_relative_positions
self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
self.reshape = P.Reshape()
self.shape_from_2d = (-1, from_tensor_width)
self.shape_to_2d = (-1, to_tensor_width)
weight = TruncatedNormal(initializer_range)
units = num_attention_heads * size_per_head
self.query_layer = nn.Dense(from_tensor_width,
units,
activation=query_act,
weight_init=weight).to_float(compute_type)
self.key_layer = nn.Dense(to_tensor_width,
units,
activation=key_act,
weight_init=weight).to_float(compute_type)
self.value_layer = nn.Dense(to_tensor_width,
units,
activation=value_act,
weight_init=weight).to_float(compute_type)
self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
self.shape_to = (
batch_size, to_seq_length, num_attention_heads, size_per_head)
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.multiply = P.Mul()
self.transpose = P.Transpose()
self.trans_shape = (0, 2, 1, 3)
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
self.batch_num = batch_size * num_attention_heads
self.matmul = P.BatchMatMul()
self.softmax = nn.Softmax()
self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
if self.has_attention_mask:
self.expand_dims = P.ExpandDims()
self.sub = P.Sub()
self.add = P.TensorAdd()
self.cast = P.Cast()
self.get_dtype = P.DType()
if do_return_2d_tensor:
self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
else:
self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
self.cast_compute_type = SaturateCast(dst_type=compute_type)
if self.use_relative_positions:
self._generate_relative_positions_embeddings = \
RelaPosEmbeddingsGenerator(length=to_seq_length,
depth=size_per_head,
max_relative_position=16,
initializer_range=initializer_range,
use_one_hot_embeddings=use_one_hot_embeddings)
def construct(self, from_tensor, to_tensor, attention_mask):
# reshape 2d/3d input tensors to 2d
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
query_out = self.query_layer(from_tensor_2d)
key_out = self.key_layer(to_tensor_2d)
value_out = self.value_layer(to_tensor_2d)
query_layer = self.reshape(query_out, self.shape_from)
query_layer = self.transpose(query_layer, self.trans_shape)
key_layer = self.reshape(key_out, self.shape_to)
key_layer = self.transpose(key_layer, self.trans_shape)
attention_scores = self.matmul_trans_b(query_layer, key_layer)
# use_relative_position, supplementary logic
if self.use_relative_positions:
# 'relations_keys' = [F|T, F|T, H]
relations_keys = self._generate_relative_positions_embeddings()
relations_keys = self.cast_compute_type(relations_keys)
# query_layer_t is [F, B, N, H]
query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
# query_layer_r is [F, B * N, H]
query_layer_r = self.reshape(query_layer_t,
(self.from_seq_length,
self.batch_num,
self.size_per_head))
# key_position_scores is [F, B * N, F|T]
key_position_scores = self.matmul_trans_b(query_layer_r,
relations_keys)
# key_position_scores_r is [F, B, N, F|T]
key_position_scores_r = self.reshape(key_position_scores,
(self.from_seq_length,
self.batch_size,
self.num_attention_heads,
self.from_seq_length))
# key_position_scores_r_t is [B, N, F, F|T]
key_position_scores_r_t = self.transpose(key_position_scores_r,
self.trans_shape_position)
attention_scores = attention_scores + key_position_scores_r_t
attention_scores = self.multiply(self.scores_mul, attention_scores)
if self.has_attention_mask:
attention_mask = self.expand_dims(attention_mask, 1)
multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
self.cast(attention_mask, self.get_dtype(attention_scores)))
adder = self.multiply(multiply_out, self.multiply_data)
attention_scores = self.add(adder, attention_scores)
attention_probs = self.softmax(attention_scores)
attention_probs = self.dropout(attention_probs)
value_layer = self.reshape(value_out, self.shape_to)
value_layer = self.transpose(value_layer, self.trans_shape)
context_layer = self.matmul(attention_probs, value_layer)
# use_relative_position, supplementary logic
if self.use_relative_positions:
# 'relations_values' = [F|T, F|T, H]
relations_values = self._generate_relative_positions_embeddings()
relations_values = self.cast_compute_type(relations_values)
# attention_probs_t is [F, B, N, T]
attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
# attention_probs_r is [F, B * N, T]
attention_probs_r = self.reshape(
attention_probs_t,
(self.from_seq_length,
self.batch_num,
self.to_seq_length))
# value_position_scores is [F, B * N, H]
value_position_scores = self.matmul(attention_probs_r,
relations_values)
# value_position_scores_r is [F, B, N, H]
value_position_scores_r = self.reshape(value_position_scores,
(self.from_seq_length,
self.batch_size,
self.num_attention_heads,
self.size_per_head))
# value_position_scores_r_t is [B, N, F, H]
value_position_scores_r_t = self.transpose(value_position_scores_r,
self.trans_shape_position)
context_layer = context_layer + value_position_scores_r_t
context_layer = self.transpose(context_layer, self.trans_shape)
context_layer = self.reshape(context_layer, self.shape_return)
return context_layer
class BertSelfAttention(nn.Cell):
"""
Apply self-attention.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence.
hidden_size (int): Size of the bert encoder layers.
num_attention_heads (int): Number of attention heads. Default: 12.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
seq_length,
hidden_size,
num_attention_heads=12,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
"of attention heads (%d)" % (hidden_size, num_attention_heads))
self.size_per_head = int(hidden_size / num_attention_heads)
self.attention = BertAttention(
batch_size=batch_size,
from_tensor_width=hidden_size,
to_tensor_width=hidden_size,
from_seq_length=seq_length,
to_seq_length=seq_length,
num_attention_heads=num_attention_heads,
size_per_head=self.size_per_head,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
use_relative_positions=use_relative_positions,
has_attention_mask=True,
do_return_2d_tensor=True,
compute_type=compute_type)
self.output = BertOutput(in_channels=hidden_size,
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
def construct(self, input_tensor, attention_mask):
input_tensor = self.reshape(input_tensor, self.shape)
attention_output = self.attention(input_tensor, input_tensor, attention_mask)
output = self.output(attention_output, input_tensor)
return output
class BertEncoderCell(nn.Cell):
"""
Encoder cells used in BertTransformer.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the bert encoder layers. Default: 768.
seq_length (int): Length of input sequence. Default: 512.
num_attention_heads (int): Number of attention heads. Default: 12.
intermediate_size (int): Size of intermediate layer. Default: 3072.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.02.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
hidden_act (str): Activation function. Default: "gelu".
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
hidden_size=768,
seq_length=512,
num_attention_heads=12,
intermediate_size=3072,
attention_probs_dropout_prob=0.02,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
enable_fused_layernorm=False):
super(BertEncoderCell, self).__init__()
self.attention = BertSelfAttention(
batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
num_attention_heads=num_attention_heads,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
self.intermediate = nn.Dense(in_channels=hidden_size,
out_channels=intermediate_size,
activation=hidden_act,
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
self.output = BertOutput(in_channels=intermediate_size,
out_channels=hidden_size,
initializer_range=initializer_range,
dropout_prob=hidden_dropout_prob,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
def construct(self, hidden_states, attention_mask):
# self-attention
attention_output = self.attention(hidden_states, attention_mask)
# feed construct
intermediate_output = self.intermediate(attention_output)
# add and normalize
output = self.output(intermediate_output, attention_output)
return output
class BertTransformer(nn.Cell):
"""
Multi-layer bert transformer.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the encoder layers.
seq_length (int): Length of input sequence.
num_hidden_layers (int): Number of hidden layers in encoder cells.
num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
attention_probs_dropout_prob (float): The dropout probability for
BertAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
"""
def __init__(self,
batch_size,
hidden_size,
seq_length,
num_hidden_layers,
num_attention_heads=12,
intermediate_size=3072,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=False,
enable_fused_layernorm=False):
super(BertTransformer, self).__init__()
self.return_all_encoders = return_all_encoders
layers = []
for _ in range(num_hidden_layers):
layer = BertEncoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
use_relative_positions=use_relative_positions,
hidden_act=hidden_act,
compute_type=compute_type,
enable_fused_layernorm=enable_fused_layernorm)
layers.append(layer)
self.layers = nn.CellList(layers)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
self.out_shape = (batch_size, seq_length, hidden_size)
def construct(self, input_tensor, attention_mask):
prev_output = self.reshape(input_tensor, self.shape)
all_encoder_layers = ()
for layer_module in self.layers:
layer_output = layer_module(prev_output, attention_mask)
prev_output = layer_output
if self.return_all_encoders:
layer_output = self.reshape(layer_output, self.out_shape)
all_encoder_layers = all_encoder_layers + (layer_output,)
if not self.return_all_encoders:
prev_output = self.reshape(prev_output, self.out_shape)
all_encoder_layers = all_encoder_layers + (prev_output,)
return all_encoder_layers
class CreateAttentionMaskFromInputMask(nn.Cell):
"""
Create attention mask according to input mask.
Args:
config (Class): Configuration for BertModel.
"""
def __init__(self, config):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.input_mask_from_dataset = config.input_mask_from_dataset
self.input_mask = None
if not self.input_mask_from_dataset:
self.input_mask = initializer(
"ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = (config.batch_size, 1, config.seq_length)
self.broadcast_ones = initializer(
"ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
self.batch_matmul = P.BatchMatMul()
def construct(self, input_mask):
if not self.input_mask_from_dataset:
input_mask = self.input_mask
input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
return attention_mask
class BertModel(nn.Cell):
"""
Bidirectional Encoder Representations from Transformers.
Args:
config (Class): Configuration for BertModel.
is_training (bool): True for training mode. False for eval mode.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
config,
is_training,
use_one_hot_embeddings=False):
super(BertModel, self).__init__()
config = copy.deepcopy(config)
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
self.input_mask_from_dataset = config.input_mask_from_dataset
self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
self.batch_size = config.batch_size
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embedding_size = config.hidden_size
self.token_type_ids = None
self.last_idx = self.num_hidden_layers - 1
output_embedding_shape = [self.batch_size, self.seq_length,
self.embedding_size]
if not self.token_type_ids_from_dataset:
self.token_type_ids = initializer(
"zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
self.bert_embedding_lookup = EmbeddingLookup(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range)
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_relative_positions=config.use_relative_positions,
use_token_type=True,
token_type_vocab_size=config.type_vocab_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=0.02,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
self.bert_encoder = BertTransformer(
batch_size=self.batch_size,
hidden_size=self.hidden_size,
seq_length=self.seq_length,
num_attention_heads=config.num_attention_heads,
num_hidden_layers=self.num_hidden_layers,
intermediate_size=config.intermediate_size,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range,
hidden_dropout_prob=config.hidden_dropout_prob,
use_relative_positions=config.use_relative_positions,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
return_all_encoders=True,
enable_fused_layernorm=config.enable_fused_layernorm)
self.cast = P.Cast()
self.dtype = config.dtype
self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
self.slice = P.StridedSlice()
self.squeeze_1 = P.Squeeze(axis=1)
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
activation="tanh",
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
def construct(self, input_ids, token_type_ids, input_mask):
# embedding
if not self.token_type_ids_from_dataset:
token_type_ids = self.token_type_ids
word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
attention_mask = self._create_attention_mask_from_input_mask(input_mask)
# bert encoder
encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(self.batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)
pooled_output = self.cast(pooled_output, self.dtype)
return sequence_output, pooled_output, embedding_tables

View File

@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''bert clue evaluation'''
import json
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
import tokenization
from sample_process import label_generation, process_one_example_p
from .evaluation_config import cfg
from .CRF import postprocess
vocab_file = "./vocab.txt"
tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file)
def process(model, text, sequence_length):
"""
process text.
"""
data = [text]
features = []
res = []
ids = []
for i in data:
feature = process_one_example_p(tokenizer_, i, max_seq_len=sequence_length)
features.append(feature)
input_ids, input_mask, token_type_id = feature
input_ids = Tensor(np.array(input_ids), mstype.int32)
input_mask = Tensor(np.array(input_mask), mstype.int32)
token_type_id = Tensor(np.array(token_type_id), mstype.int32)
if cfg.use_crf:
backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
best_path = postprocess(backpointers, best_tag_id)
logits = []
for ele in best_path:
logits.extend(ele)
ids = logits
else:
logits = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
ids = logits.asnumpy()
ids = np.argmax(ids, axis=-1)
ids = list(ids)
res = label_generation(text, ids)
return res
def submit(model, path, sequence_length):
"""
submit task
"""
data = []
for line in open(path):
if not line.strip():
continue
oneline = json.loads(line.strip())
res = process(model, oneline["text"], sequence_length)
print("text", oneline["text"])
print("res:", res)
data.append(json.dumps({"label": res}, ensure_ascii=False))
open("ner_predict.json", "w").write("\n".join(data))

View File

@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, run_pretrain.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .bert_model import BertConfig
cfg = edict({
'bert_network': 'base',
'loss_scale_value': 65536,
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({
'learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 5.0,
'weight_decay': 1e-5,
'eps': 1e-6,
'warmup_steps': 10000,
}),
'Lamb': edict({
'start_learning_rate': 3e-5,
'end_learning_rate': 1e-10,
'power': 10.0,
'warmup_steps': 10000,
'weight_decay': 0.01,
'eps': 1e-6,
}),
'Momentum': edict({
'learning_rate': 2e-5,
'momentum': 0.9,
}),
})
'''
Including two kinds of network: \
base: Goole BERT-base(the base version of BERT model).
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \
Functional Relative Posetional Encoding as an effective positional encoding scheme).
'''
if cfg.bert_network == 'base':
bert_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=21136,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16
)
if cfg.bert_network == 'nezha':
bert_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16
)
if cfg.bert_network == 'large':
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=512,
vocab_size=30528,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
enable_fused_layernorm=True
)

View File

@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in run_pretrain.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import log as logger
from .config import bert_net_cfg
def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
data_sink_steps=1, data_dir=None, schema_dir=None):
"""create train dataset"""
# apply repeat operations
repeat_count = epoch_size
files = os.listdir(data_dir)
data_files = []
for file_name in files:
if "tfrecord" in file_name:
data_files.append(os.path.join(data_dir, file_name))
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
shard_equal_rows=True)
ori_dataset_size = ds.get_dataset_size()
new_size = ori_dataset_size
if enable_data_sink == "true":
new_size = data_sink_steps * bert_net_cfg.batch_size
ds.set_dataset_size(new_size)
new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(new_repeat_count)
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds, new_repeat_count

View File

@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
config settings, will be used in finetune.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .bert_model import BertConfig
cfg = edict({
'task': 'NER',
'num_labels': 41,
'data_file': '/your/path/evaluation.tfrecord',
'schema_file': '/your/path/schema.json',
'finetune_ckpt': '/your/path/your.ckpt',
'use_crf': False,
'clue_benchmark': False,
})
bert_net_cfg = BertConfig(
batch_size=16 if not cfg.clue_benchmark else 1,
seq_length=128,
vocab_size=21128,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)

View File

@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
config settings, will be used in finetune.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .bert_model import BertConfig
cfg = edict({
'task': 'NER',
'num_labels': 41,
'data_file': '/your/path/train.tfrecord',
'schema_file': '/your/path/schema.json',
'epoch_num': 5,
'ckpt_prefix': 'bert',
'ckpt_dir': None,
'pre_training_ckpt': '/your/path/pre_training.ckpt',
'use_crf': False,
'optimizer': 'Lamb',
'AdamWeightDecayDynamicLR': edict({
'learning_rate': 2e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'weight_decay': 1e-5,
'eps': 1e-6,
}),
'Lamb': edict({
'start_learning_rate': 2e-5,
'end_learning_rate': 1e-7,
'power': 1.0,
'decay_filter': lambda x: False,
}),
'Momentum': edict({
'learning_rate': 2e-5,
'momentum': 0.9,
}),
})
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21128,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)
tag_to_index = {
"O": 0,
"S_address": 1,
"B_address": 2,
"M_address": 3,
"E_address": 4,
"S_book": 5,
"B_book": 6,
"M_book": 7,
"E_book": 8,
"S_company": 9,
"B_company": 10,
"M_company": 11,
"E_company": 12,
"S_game": 13,
"B_game": 14,
"M_game": 15,
"E_game": 16,
"S_government": 17,
"B_government": 18,
"M_government": 19,
"E_government": 20,
"S_movie": 21,
"B_movie": 22,
"M_movie": 23,
"E_movie": 24,
"S_name": 25,
"B_name": 26,
"M_name": 27,
"E_name": 28,
"S_organization": 29,
"B_organization": 30,
"M_organization": 31,
"E_organization": 32,
"S_position": 33,
"B_position": 34,
"M_position": 35,
"E_position": 36,
"S_scene": 37,
"B_scene": 38,
"M_scene": 39,
"E_scene": 40,
"<START>": 41,
"<STOP>": 42
}

View File

@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell
import numpy as np
__all__ = ['FusedLayerNorm']
@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
print("input_shape: ", x_shape)
norm_shape = x_shape[begin_norm_axis:]
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
print("output_shape: ", output_shape)
return output_shape
class FusedLayerNorm(Cell):
r"""
Applies Layer Normalization over a mini-batch of inputs.
Layer normalization is widely used in recurrent neural networks. It applies
normalization over a mini-batch of inputs for each single training case as described
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
normalization, layer normalization performs exactly the same computation at training and
testing times. It can be described using the following formula. It is applied across all channels
and pixel but only one batch size.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Args:
normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis
`begin_norm_axis ... R - 1`.
begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
use_batch_nrom (bool): Whether use batchnorm to preocess.
Inputs:
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
Outputs:
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
>>> shape1 = x.shape()[1:]
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x)
"""
def __init__(self,
normalized_shape,
begin_norm_axis=-1,
begin_params_axis=-1,
gamma_init='ones',
beta_init='zeros',
use_batch_norm=False):
super(FusedLayerNorm, self).__init__()
if not isinstance(normalized_shape, (tuple, list)):
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
.format(normalized_shape, type(normalized_shape)))
self.normalized_shape = normalized_shape
self.begin_norm_axis = begin_norm_axis
self.begin_params_axis = begin_params_axis
self.gamma = Parameter(initializer(
gamma_init, normalized_shape), name="gamma")
self.beta = Parameter(initializer(
beta_init, normalized_shape), name="beta")
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
self.use_batch_norm = use_batch_norm
def construct(self, input_x):
if self.use_batch_norm and self.training:
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
shape_x = F.shape(input_x)
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
input_x = F.reshape(input_x, norm_shape)
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
output = F.reshape(output, shape_x)
y = output * self.gamma + self.beta
else:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
return y
def extend_repr(self):
"""Display instance object as string."""
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
return s

View File

@ -0,0 +1,100 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""process txt"""
import re
import json
def process_one_example_p(tokenizer, text, max_seq_len=128):
"""process one testline"""
textlist = list(text)
tokens = []
for _, word in enumerate(textlist):
token = tokenizer.tokenize(word)
tokens.extend(token)
if len(tokens) >= max_seq_len - 1:
tokens = tokens[0:(max_seq_len - 2)]
ntokens = []
segment_ids = []
label_ids = []
ntokens.append("[CLS]")
segment_ids.append(0)
for _, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
ntokens.append("[SEP]")
segment_ids.append(0)
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_len:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(0)
ntokens.append("**NULL**")
assert len(input_ids) == max_seq_len
assert len(input_mask) == max_seq_len
assert len(segment_ids) == max_seq_len
feature = (input_ids, input_mask, segment_ids)
return feature
def label_generation(text, probs):
"""generate label"""
data = [text]
probs = [probs]
result = []
label2id = json.loads(open("./label2id.json").read())
id2label = [k for k, v in label2id.items()]
for index, prob in enumerate(probs):
for v in prob[1:len(data[index]) + 1]:
result.append(id2label[int(v)])
labels = {}
start = None
index = 0
for _, t in zip("".join(data), result):
if re.search("^[BS]", t):
if start is not None:
label = result[index - 1][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
start = index
if re.search("^O", t):
if start is not None:
label = result[index - 1][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
start = None
index += 1
if start is not None:
label = result[start][2:]
if labels.get(label):
te_ = text[start:index]
labels[label][te_] = [[start, index - 1]]
else:
te_ = text[start:index]
labels[label] = {te_: [[start, index - 1]]}
return labels

View File

@ -0,0 +1,263 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
Functional Cells used in Bert finetune and evaluation.
'''
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
from .bert_for_pre_training import clip_grad
from .CRF import CRF
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class BertFinetuneCell(nn.Cell):
"""
Especifically defined for finetuning where only four inputs tensor are needed.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(BertFinetuneCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad',
get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = None
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
def construct(self,
input_ids,
input_mask,
token_type_id,
label_ids,
sens=None):
weights = self.weights
init = self.alloc_status()
loss = self.network(input_ids,
input_mask,
token_type_id,
label_ids)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
grads = self.grad(self.network, weights)(input_ids,
input_mask,
token_type_id,
label_ids,
self.cast(scaling_sens,
mstype.float32))
clear_before_grad = self.clear_before_grad(init)
F.control_depend(loss, init)
self.depend_parameter_use(clear_before_grad, scaling_sens)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag:
grads = self.grad_reducer(grads)
flag = self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if self.is_distributed:
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
F.control_depend(grads, flag)
F.control_depend(flag, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond)
return F.depend(ret, succ)
class BertCLSModel(nn.Cell):
"""
This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final
logits as the results of log_softmax is propotional to that of softmax.
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertCLSModel, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cast = P.Cast()
self.weight_init = TruncatedNormal(config.initializer_range)
self.log_softmax = P.LogSoftmax(axis=-1)
self.dtype = config.dtype
self.num_labels = num_labels
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
def construct(self, input_ids, input_mask, token_type_id):
_, pooled_output, _ = \
self.bert(input_ids, token_type_id, input_mask)
cls = self.cast(pooled_output, self.dtype)
cls = self.dropout(cls)
logits = self.dense_1(cls)
logits = self.cast(logits, self.dtype)
log_probs = self.log_softmax(logits)
return log_probs
class BertNERModel(nn.Cell):
"""
This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
The returned output represents the final logits as the results of log_softmax is propotional to that of softmax.
"""
def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
use_one_hot_embeddings=False):
super(BertNERModel, self).__init__()
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
self.cast = P.Cast()
self.weight_init = TruncatedNormal(config.initializer_range)
self.log_softmax = P.LogSoftmax(axis=-1)
self.dtype = config.dtype
self.num_labels = num_labels
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
self.reshape = P.Reshape()
self.shape = (-1, config.hidden_size)
self.use_crf = use_crf
self.origin_shape = (config.batch_size, config.seq_length, self.num_labels)
def construct(self, input_ids, input_mask, token_type_id):
sequence_output, _, _ = \
self.bert(input_ids, token_type_id, input_mask)
seq = self.dropout(sequence_output)
seq = self.reshape(seq, self.shape)
logits = self.dense_1(seq)
logits = self.cast(logits, self.dtype)
if self.use_crf:
return_value = self.reshape(logits, self.origin_shape)
else:
return_value = self.log_softmax(logits)
return return_value
class CrossEntropyCalculation(nn.Cell):
"""
Cross Entropy loss
"""
def __init__(self, is_training=True):
super(CrossEntropyCalculation, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.reduce_sum = P.ReduceSum()
self.reduce_mean = P.ReduceMean()
self.reshape = P.Reshape()
self.last_idx = (-1,)
self.neg = P.Neg()
self.cast = P.Cast()
self.is_training = is_training
def construct(self, logits, label_ids, num_labels):
if self.is_training:
label_ids = self.reshape(label_ids, self.last_idx)
one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
loss = self.reduce_mean(per_example_loss, self.last_idx)
return_value = self.cast(loss, mstype.float32)
else:
return_value = logits * 1.0
return return_value
class BertCLS(nn.Cell):
"""
Train interface for classification finetuning task.
"""
def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
super(BertCLS, self).__init__()
self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
self.loss = CrossEntropyCalculation(is_training)
self.num_labels = num_labels
def construct(self, input_ids, input_mask, token_type_id, label_ids):
log_probs = self.bert(input_ids, input_mask, token_type_id)
loss = self.loss(log_probs, label_ids, self.num_labels)
return loss
class BertNER(nn.Cell):
"""
Train interface for sequence labeling finetuning task.
"""
def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0,
use_one_hot_embeddings=False):
super(BertNER, self).__init__()
self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings)
if use_crf:
if not tag_to_index:
raise Exception("The dict for tag-index mapping should be provided for CRF.")
self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training)
else:
self.loss = CrossEntropyCalculation(is_training)
self.num_labels = num_labels
self.use_crf = use_crf
def construct(self, input_ids, input_mask, token_type_id, label_ids):
logits = self.bert(input_ids, input_mask, token_type_id)
if self.use_crf:
loss = self.loss(logits, label_ids)
else:
loss = self.loss(logits, label_ids, self.num_labels)
return loss

View File

@ -1,52 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test bert cell """
import numpy as np
import pytest
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertModel
from ....dataset_mock import MindData
def map_bert(record):
target_data = {'input_ids': None, 'input_mask': None,
'segment_ids': None, 'next_sentence_labels': None,
'masked_lm_positions': None, 'masked_lm_ids': None,
'masked_lm_weights': None}
sample = dt.parse_single_example(record, target_data)
return sample['input_ids'], sample['input_mask'], sample['segment_ids'], \
sample['next_sentence_labels'], sample['masked_lm_positions'], \
sample['masked_lm_ids'], sample['masked_lm_weights']
def test_bert_model():
# test for config.hidden_size % config.num_attention_heads != 0
config_error = BertConfig(32, hidden_size=512, num_attention_heads=10)
with pytest.raises(ValueError):
BertModel(config_error, True)
def get_dataset(batch_size=1):
dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
(batch_size, 20), (batch_size, 20), (batch_size, 20))
dataset = MindData(size=2, batch_size=batch_size,
np_types=dataset_types,
output_shapes=dataset_shapes,
input_indexs=(0, 1))
return dataset

View File

@ -1,437 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test bert of graph compile """
import functools
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
BertEncoderCell, BertTransformer, CreateAttentionMaskFromInputMask, BertModel
from mindspore.nn.layer.basic import Norm
from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward import \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.gradient.compile_gradient import \
pipeline_for_compile_grad_ge_graph_for_case_by_case_config
from ....ops_common import convert
def bert_trans():
"""bert_trans"""
net = BertTransformer(batch_size=1,
hidden_size=768,
seq_length=128,
num_hidden_layers=1,
num_attention_heads=12,
intermediate_size=768,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=True)
net.set_train()
return net
def set_train(net):
net.set_train()
return net
class NetForAdam(nn.Cell):
def __init__(self):
super(NetForAdam, self).__init__()
self.dense = nn.Dense(64, 10)
def construct(self, x):
x = self.dense(x)
return x
class TrainStepWrapForAdam(nn.Cell):
"""TrainStepWrapForAdam definition"""
def __init__(self, network):
super(TrainStepWrapForAdam, self).__init__()
self.network = network
self.weights = ParameterTuple(network.get_parameters())
self.optimizer = AdamWeightDecay(self.weights)
self.hyper_map = C.HyperMap()
def construct(self, x, sens):
weights = self.weights
grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
grads = self.hyper_map(F.partial(clip_grad, 1, 1.0), grads)
return self.optimizer(grads)
class TrainStepWrapForAdamDynamicLr(nn.Cell):
"""TrainStepWrapForAdamDynamicLr definition"""
def __init__(self, network):
super(TrainStepWrapForAdamDynamicLr, self).__init__()
self.network = network
self.weights = ParameterTuple(network.get_parameters())
self.optimizer = AdamWeightDecayDynamicLR(self.weights, 10)
self.sens = Tensor(np.ones(shape=(1, 10)).astype(np.float32))
def construct(self, x):
weights = self.weights
grads = C.grad_by_list_with_sens(self.network, weights)(x, self.sens)
return self.optimizer(grads)
class TempC2Wrap(nn.Cell):
def __init__(self, op, c1=None, c2=None,):
super(TempC2Wrap, self).__init__()
self.op = op
self.c1 = c1
self.c2 = c2
self.hyper_map = C.HyperMap()
def construct(self, x1):
x = self.hyper_map(F.partial(self.op, self.c1, self.c2), x1)
return x
test_case_cell_ops = [
('Norm_keepdims', {
'block': Norm(keep_dims=True),
'desc_inputs': [[1, 3, 4, 4]],
'desc_bprop': [[1]]}),
('SaturateCast', {
'block': SaturateCast(),
'desc_inputs': [[1, 3, 4, 4]],
'desc_bprop': [[1, 3, 4, 4]]}),
('RelaPosMatrixGenerator_0', {
'block': RelaPosMatrixGenerator(length=128, max_relative_position=16),
'desc_inputs': [],
'desc_bprop': [[128, 128]],
'skip': ['backward']}),
('RelaPosEmbeddingsGenerator_0', {
'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
max_relative_position=16,
initializer_range=0.2),
'desc_inputs': [],
'desc_bprop': [[16384, 512]],
'skip': ['backward']}),
('RelaPosEmbeddingsGenerator_1', {
'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
max_relative_position=16,
initializer_range=0.2,
use_one_hot_embeddings=False),
'desc_inputs': [],
'desc_bprop': [[128, 128, 512]],
'skip': ['backward']}),
('RelaPosEmbeddingsGenerator_2', {
'block': RelaPosEmbeddingsGenerator(length=128, depth=64,
max_relative_position=16,
initializer_range=0.2,
use_one_hot_embeddings=False),
'desc_inputs': [],
'desc_bprop': [[128, 128, 64]],
'skip': ['backward']}),
('BertAttention_0', {
'block': BertAttention(batch_size=64,
from_tensor_width=768,
to_tensor_width=768,
from_seq_length=128,
to_seq_length=128,
num_attention_heads=12,
size_per_head=64,
query_act=None,
key_act=None,
value_act=None,
has_attention_mask=True,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
do_return_2d_tensor=True,
use_relative_positions=False,
compute_type=mstype.float32),
'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
'desc_bprop': [[8192, 768]]}),
('BertAttention_1', {
'block': BertAttention(batch_size=64,
from_tensor_width=768,
to_tensor_width=768,
from_seq_length=128,
to_seq_length=128,
num_attention_heads=12,
size_per_head=64,
query_act=None,
key_act=None,
value_act=None,
has_attention_mask=True,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
do_return_2d_tensor=True,
use_relative_positions=True,
compute_type=mstype.float32),
'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
'desc_bprop': [[8192, 768]]}),
('BertAttention_2', {
'block': BertAttention(batch_size=64,
from_tensor_width=768,
to_tensor_width=768,
from_seq_length=128,
to_seq_length=128,
num_attention_heads=12,
size_per_head=64,
query_act=None,
key_act=None,
value_act=None,
has_attention_mask=False,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
do_return_2d_tensor=True,
use_relative_positions=True,
compute_type=mstype.float32),
'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
'desc_bprop': [[8192, 768]]}),
('BertAttention_3', {
'block': BertAttention(batch_size=64,
from_tensor_width=768,
to_tensor_width=768,
from_seq_length=128,
to_seq_length=128,
num_attention_heads=12,
size_per_head=64,
query_act=None,
key_act=None,
value_act=None,
has_attention_mask=True,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
do_return_2d_tensor=False,
use_relative_positions=True,
compute_type=mstype.float32),
'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
'desc_bprop': [[8192, 768]]}),
('BertOutput', {
'block': BertOutput(in_channels=768,
out_channels=768,
initializer_range=0.02,
dropout_prob=0.1),
'desc_inputs': [[8192, 768], [8192, 768]],
'desc_bprop': [[8192, 768]]}),
('BertSelfAttention_0', {
'block': BertSelfAttention(batch_size=64,
seq_length=128,
hidden_size=768,
num_attention_heads=12,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
compute_type=mstype.float32),
'desc_inputs': [[64, 128, 768], [64, 128, 128]],
'desc_bprop': [[8192, 768]]}),
('BertEncoderCell', {
'block': BertEncoderCell(batch_size=64,
hidden_size=768,
seq_length=128,
num_attention_heads=12,
intermediate_size=768,
attention_probs_dropout_prob=0.02,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32),
'desc_inputs': [[64, 128, 768], [64, 128, 128]],
'desc_bprop': [[8192, 768]]}),
('BertTransformer_0', {
'block': BertTransformer(batch_size=1,
hidden_size=768,
seq_length=128,
num_hidden_layers=1,
num_attention_heads=12,
intermediate_size=768,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
use_relative_positions=False,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=True),
'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
('BertTransformer_1', {
'block': BertTransformer(batch_size=64,
hidden_size=768,
seq_length=128,
num_hidden_layers=2,
num_attention_heads=12,
intermediate_size=768,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
use_relative_positions=True,
hidden_act="gelu",
compute_type=mstype.float32,
return_all_encoders=False),
'desc_inputs': [[64, 128, 768], [64, 128, 128]]}),
('EmbeddingLookup', {
'block': EmbeddingLookup(vocab_size=32000,
embedding_size=768,
embedding_shape=[1, 128, 768],
use_one_hot_embeddings=False,
initializer_range=0.02),
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32))],
'desc_bprop': [[1, 128, 768], [1, 128, 768]],
'num_output': 2}),
('EmbeddingPostprocessor', {
'block': EmbeddingPostprocessor(embedding_size=768,
embedding_shape=[1, 128, 768],
use_token_type=True,
token_type_vocab_size=16,
use_one_hot_embeddings=False,
initializer_range=0.02,
max_position_embeddings=512,
dropout_prob=0.1),
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), [1, 128, 768]],
'desc_bprop': [[1, 128, 768]]}),
('CreateAttentionMaskFromInputMask', {
'block': CreateAttentionMaskFromInputMask(config=BertConfig(batch_size=1)),
'desc_inputs': [[128]],
'desc_bprop': [[1, 128, 128]]}),
('BertOutput_0', {
'block': BertOutput(in_channels=768,
out_channels=768,
initializer_range=0.02,
dropout_prob=0.1),
'desc_inputs': [[1, 768], [1, 768]],
'desc_bprop': [[1, 768]]}),
('BertTransformer_2', {
'block': bert_trans(),
'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
('BertModel', {
'block': BertModel(config=BertConfig(batch_size=1,
num_hidden_layers=1,
intermediate_size=768,
token_type_ids_from_dataset=False),
is_training=True),
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
Tensor(np.random.rand(128).astype(np.int32)), [128]],
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
'num_output': 3}),
('BertModel_1', {
'block': BertModel(config=BertConfig(batch_size=1,
num_hidden_layers=1,
intermediate_size=768,
token_type_ids_from_dataset=False),
is_training=False),
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
Tensor(np.random.rand(128).astype(np.int32)), [128]],
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
'num_output': 3}),
('BertModel_2', {
'block': BertModel(config=BertConfig(batch_size=1,
num_hidden_layers=1,
intermediate_size=768,
token_type_ids_from_dataset=False,
input_mask_from_dataset=False),
is_training=True),
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
Tensor(np.random.rand(128).astype(np.int32)), [128]],
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
'num_output': 3}),
('BertPretrainingLoss', {
'block': BertPretrainingLoss(config=BertConfig(batch_size=1)),
'desc_inputs': [[32000], [20, 2], Tensor(np.array([1]).astype(np.int32)),
[20], Tensor(np.array([20]).astype(np.int32))],
'desc_bprop': [[1]],
'num_output': 1}),
('Dense_1', {
'block': nn.Dense(in_channels=768,
out_channels=3072,
activation='gelu',
weight_init=TruncatedNormal(0.02)),
'desc_inputs': [[3, 768]],
'desc_bprop': [[3, 3072]]}),
('Dense_2', {
'block': set_train(nn.Dense(in_channels=768,
out_channels=3072,
activation='gelu',
weight_init=TruncatedNormal(0.02),)),
'desc_inputs': [[3, 768]],
'desc_bprop': [[3, 3072]]}),
('GetNextSentenceOutput', {
'block': GetNextSentenceOutput(BertConfig(batch_size=1)),
'desc_inputs': [[128, 768]],
'desc_bprop': [[128, 2]]}),
('Adam_1', {
'block': set_train(TrainStepWrapForAdam(NetForAdam())),
'desc_inputs': [[1, 64], [1, 10]],
'skip': ['backward']}),
('Adam_2', {
'block': set_train(TrainStepWrapForAdam(GetNextSentenceOutput(BertConfig(batch_size=1)))),
'desc_inputs': [[128, 768], [128, 2]],
'skip': ['backward']}),
('AdamWeightDecayDynamicLR', {
'block': set_train(TrainStepWrapForAdamDynamicLr(NetForAdam())),
'desc_inputs': [[1, 64]],
'skip': ['backward']}),
('ClipGradients', {
'block': TempC2Wrap(clip_grad, 1, 1.0),
'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
'skip': ['backward', 'exec']}),
]
test_case = functools.reduce(lambda x, y: x + y, [test_case_cell_ops])
# use -k to select certain testcast
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
test_exec_case = filter(lambda x: 'skip' not in x[1] or
'exec' not in x[1]['skip'], test_case)
test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
'backward' not in x[1]['skip'] and 'backward_exec'
not in x[1]['skip'], test_case)
test_check_gradient_case = filter(lambda x: 'skip' not in x[1] or
'backward' not in x[1]['skip'] and 'backward_exec'
not in x[1]['skip'], test_case)
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
return test_exec_case
@mindspore_test(pipeline_for_compile_grad_ge_graph_for_case_by_case_config)
def test_backward_exec():
return test_backward_exec_case

View File

@ -1,66 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_embedding """
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import EmbeddingLookup, EmbeddingPostprocessor
from ..ut_filter import non_graph_engine
@non_graph_engine
def test_check_embedding_lookup_1():
m = EmbeddingLookup(vocab_size=32000,
embedding_size=768,
embedding_shape=[1, 128, 768],
use_one_hot_embeddings=False)
m(Tensor(np.ones([128]), mstype.int32))
@non_graph_engine
def test_check_embedding_lookup_2():
m = EmbeddingLookup(vocab_size=32000,
embedding_size=768,
embedding_shape=[1, 128, 768],
use_one_hot_embeddings=True)
m(Tensor(np.ones([128]), mstype.int32))
@non_graph_engine
def test_check_embedding_lookup_3():
m = EmbeddingLookup(vocab_size=32000,
embedding_size=768,
embedding_shape=[1, 128, 768],
use_one_hot_embeddings=True,
initializer_range=0.01)
m(Tensor(np.ones([128]), mstype.int32))
@non_graph_engine
def test_embedding_post_1():
m = EmbeddingPostprocessor(embedding_size=768,
embedding_shape=[1, 128, 768],
use_token_type=True)
m(Tensor(np.ones([128]), mstype.int32), Tensor(np.ones([1, 128, 768]), mstype.float32))
@non_graph_engine
def test_embedding_post_2():
m = EmbeddingPostprocessor(embedding_size=768,
embedding_shape=[1, 128, 768],
use_token_type=True,
initializer_range=0.3)
m(Tensor(np.ones([128]), mstype.int32), Tensor(np.ones([1, 128, 768]), mstype.float32))